diff options
86 files changed, 3060 insertions, 2032 deletions
@@ -152,7 +152,7 @@ nogo: ## Surfaces all nogo findings. .PHONY: nogo gazelle: ## Runs gazelle to update WORKSPACE. - @$(call run,//:gazelle update-repos -from_file=go.mod -prune) + @$(call run,//:gazelle,update-repos -from_file=go.mod -prune) .PHONY: gazelle ## diff --git a/images/basic/hostoverlaytest/Dockerfile b/images/basic/hostoverlaytest/Dockerfile deleted file mode 100644 index 6cef1a542..000000000 --- a/images/basic/hostoverlaytest/Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY . . - -RUN apt-get update && apt-get install -y gcc -RUN gcc -O2 -o test_copy_up test_copy_up.c -RUN gcc -O2 -o test_rewinddir test_rewinddir.c diff --git a/images/basic/integrationtest/Dockerfile b/images/basic/integrationtest/Dockerfile new file mode 100644 index 000000000..e80e17527 --- /dev/null +++ b/images/basic/integrationtest/Dockerfile @@ -0,0 +1,7 @@ +FROM ubuntu:bionic + +WORKDIR /root +COPY . . +RUN chmod +x *.sh + +RUN apt-get update && apt-get install -y gcc iputils-ping iproute2 diff --git a/images/basic/hostoverlaytest/copy_up_testfile.txt b/images/basic/integrationtest/copy_up_testfile.txt index e4188c841..e4188c841 100644 --- a/images/basic/hostoverlaytest/copy_up_testfile.txt +++ b/images/basic/integrationtest/copy_up_testfile.txt diff --git a/images/basic/linktest/link_test.c b/images/basic/integrationtest/link_test.c index 45ab00abe..45ab00abe 100644 --- a/images/basic/linktest/link_test.c +++ b/images/basic/integrationtest/link_test.c diff --git a/images/basic/ping4test/ping4.sh b/images/basic/integrationtest/ping4.sh index 2a343712a..2a343712a 100644 --- a/images/basic/ping4test/ping4.sh +++ b/images/basic/integrationtest/ping4.sh diff --git a/images/basic/ping6test/ping6.sh b/images/basic/integrationtest/ping6.sh index 4268951d0..4268951d0 100644 --- a/images/basic/ping6test/ping6.sh +++ b/images/basic/integrationtest/ping6.sh diff --git a/images/basic/hostoverlaytest/test_copy_up.c b/images/basic/integrationtest/test_copy_up.c index 010b261dc..010b261dc 100644 --- a/images/basic/hostoverlaytest/test_copy_up.c +++ b/images/basic/integrationtest/test_copy_up.c diff --git a/images/basic/hostoverlaytest/test_rewinddir.c b/images/basic/integrationtest/test_rewinddir.c index f1a4085e1..f1a4085e1 100644 --- a/images/basic/hostoverlaytest/test_rewinddir.c +++ b/images/basic/integrationtest/test_rewinddir.c diff --git a/images/basic/linktest/Dockerfile b/images/basic/linktest/Dockerfile deleted file mode 100644 index baebc9b76..000000000 --- a/images/basic/linktest/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY . . - -RUN apt-get update && apt-get install -y gcc -RUN gcc -O2 -o link_test link_test.c diff --git a/images/basic/ping4test/Dockerfile b/images/basic/ping4test/Dockerfile deleted file mode 100644 index 1536be376..000000000 --- a/images/basic/ping4test/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY ping4.sh . -RUN chmod +x ping4.sh - -RUN apt-get update && apt-get install -y iputils-ping diff --git a/images/basic/ping6test/Dockerfile b/images/basic/ping6test/Dockerfile deleted file mode 100644 index cb740bd60..000000000 --- a/images/basic/ping6test/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY ping6.sh . -RUN chmod +x ping6.sh - -RUN apt-get update && apt-get install -y iputils-ping iproute2 diff --git a/images/syzkaller/README.md b/images/syzkaller/README.md index 60aa61d12..1eac474f3 100644 --- a/images/syzkaller/README.md +++ b/images/syzkaller/README.md @@ -1,34 +1,25 @@ syzkaller is an unsupervised coverage-guided kernel fuzzer. -* [Github](https://github.com/google/syzkaller) -* [gVisor dashboard](https://syzkaller.appspot.com/gvisor) +* [Github](https://github.com/google/syzkaller) +* [gVisor dashboard](https://syzkaller.appspot.com/gvisor) # How to run syzkaller. -* Build the syzkaller docker image -``` -make load-syzkaller -``` -* Build runsc and place it in /tmp/syzkaller. -``` - make RUNTIME_DIR=/tmp/syzkaller refresh -``` -* Copy the syzkaller config in /tmp/syzkaller -``` -cp images/syzkaller/default-gvisor-config.cfg /tmp/syzkaller/syzkaller.cfg -``` -* Run syzkaller -``` -docker run --privileged -it --rm -v /tmp/syzkaller:/tmp/syzkaller gvisor.dev/images/syzkaller:latest -``` +* Build the syzkaller docker image `make load-syzkaller` +* Build runsc and place it in /tmp/syzkaller. `make RUNTIME_DIR=/tmp/syzkaller + refresh` +* Copy the syzkaller config in /tmp/syzkaller `cp + images/syzkaller/default-gvisor-config.cfg /tmp/syzkaller/syzkaller.cfg` +* Run syzkaller `docker run --privileged -it --rm -v + /tmp/syzkaller:/tmp/syzkaller gvisor.dev/images/syzkaller:latest` # How to run a syz repro. -* Repeate all steps except the last one from the previous section. +* Repeate all steps except the last one from the previous section. -* Save a syzkaller repro in /tmp/syzkaller/repro +* Save a syzkaller repro in /tmp/syzkaller/repro -* Run syz-repro -``` -docker run --privileged -it --rm -v /tmp/syzkaller:/tmp/syzkaller --entrypoint="" gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config /tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro -``` +* Run syz-repro `docker run --privileged -it --rm -v + /tmp/syzkaller:/tmp/syzkaller --entrypoint="" + gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config + /tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro` @@ -156,3 +156,4 @@ analyzers: internal: exclude: - pkg/sentry/fs/fdpipe/pipe_opener_test.go # False positive. + - pkg/tcpip/tests/integration/link_resolution_test.go # False positive. diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index cb33c37bd..185eee0bb 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -347,26 +347,57 @@ const SizeOfLinger = 8 // // +marshal type TCPInfo struct { - State uint8 - CaState uint8 + // State is the state of the connection. + State uint8 + + // CaState is the congestion control state. + CaState uint8 + + // Retransmits is the number of retransmissions triggered by RTO. Retransmits uint8 - Probes uint8 - Backoff uint8 - Options uint8 - // WindowScale is the combination of snd_wscale (first 4 bits) and rcv_wscale (second 4 bits) + + // Probes is the number of unanswered zero window probes. + Probes uint8 + + // BackOff indicates exponential backoff. + Backoff uint8 + + // Options indicates the options enabled for the connection. + Options uint8 + + // WindowScale is the combination of snd_wscale (first 4 bits) and + // rcv_wscale (second 4 bits) WindowScale uint8 - // DeliveryRateAppLimited is a boolean and only the first bit is meaningful. + + // DeliveryRateAppLimited is a boolean and only the first bit is + // meaningful. DeliveryRateAppLimited uint8 - RTO uint32 - ATO uint32 + // RTO is the retransmission timeout. + RTO uint32 + + // ATO is the acknowledgement timeout interval. + ATO uint32 + + // SndMss is the send maximum segment size. SndMss uint32 + + // RcvMss is the receive maximum segment size. RcvMss uint32 + // Unacked is the number of packets sent but not acknowledged. Unacked uint32 - Sacked uint32 - Lost uint32 + + // Sacked is the number of packets which are selectively acknowledged. + Sacked uint32 + + // Lost is the number of packets marked as lost. + Lost uint32 + + // Retrans is the number of retransmitted packets. Retrans uint32 + + // Fackets is not used and is always zero. Fackets uint32 // Times. @@ -385,48 +416,77 @@ type TCPInfo struct { Advmss uint32 Reordering uint32 - RcvRTT uint32 + // RcvRTT is the receiver round trip time. + RcvRTT uint32 + + // RcvSpace is the current buffer space available for receiving data. RcvSpace uint32 + // TotalRetrans is the total number of retransmits seen since the start + // of the connection. TotalRetrans uint32 - PacingRate uint64 + // PacingRate is the pacing rate in bytes per second. + PacingRate uint64 + + // MaxPacingRate is the maximum pacing rate. MaxPacingRate uint64 + // BytesAcked is RFC4898 tcpEStatsAppHCThruOctetsAcked. BytesAcked uint64 + // BytesReceived is RFC4898 tcpEStatsAppHCThruOctetsReceived. BytesReceived uint64 + // SegsOut is RFC4898 tcpEStatsPerfSegsOut. SegsOut uint32 + // SegsIn is RFC4898 tcpEStatsPerfSegsIn. SegsIn uint32 + // NotSentBytes is the amount of bytes in the write queue that are not + // yet sent. NotSentBytes uint32 - MinRTT uint32 + + // MinRTT is the minimum round trip time seen in the connection. + MinRTT uint32 + // DataSegsIn is RFC4898 tcpEStatsDataSegsIn. DataSegsIn uint32 + // DataSegsOut is RFC4898 tcpEStatsDataSegsOut. DataSegsOut uint32 + // DeliveryRate is the most recent delivery rate in bytes per second. DeliveryRate uint64 // BusyTime is the time in microseconds busy sending data. BusyTime uint64 + // RwndLimited is the time in microseconds limited by receive window. RwndLimited uint64 + // SndBufLimited is the time in microseconds limited by send buffer. SndBufLimited uint64 - Delievered uint32 - DelieveredCe uint32 + // Delivered is the total data packets delivered including retransmits. + Delivered uint32 + + // DeliveredCE is the total ECE marked data packets delivered including + // retransmits. + DeliveredCE uint32 // BytesSent is RFC4898 tcpEStatsPerfHCDataOctetsOut. BytesSent uint64 + // BytesRetrans is RFC4898 tcpEStatsPerfOctetsRetrans. BytesRetrans uint64 + // DSACKDups is RFC4898 tcpEStatsStackDSACKDups. DSACKDups uint32 - // ReordSeen is the number of reordering events seen. + + // ReordSeen is the number of reordering events seen since the start of + // the connection. ReordSeen uint32 } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 94f03af48..69693f263 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2666,9 +2666,9 @@ func (s *socketOpsCommon) dequeueErr() *tcpip.SockError { } // Update socket error to reflect ICMP errors in queue. - if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() { + if nextErr := so.PeekErr(); nextErr != nil && nextErr.Cause.Origin().IsICMPErr() { so.SetLastError(nextErr.Err) - } else if err.ErrOrigin.IsICMPErr() { + } else if err.Cause.Origin().IsICMPErr() { so.SetLastError(nil) } return err diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 97729dacc..cc535d794 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -81,10 +81,10 @@ func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { ee := linux.SockExtendedErr{ Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), - Origin: errOriginToLinux(sockErr.ErrOrigin), - Type: sockErr.ErrType, - Code: sockErr.ErrCode, - Info: sockErr.ErrInfo, + Origin: errOriginToLinux(sockErr.Cause.Origin()), + Type: sockErr.Cause.Type(), + Code: sockErr.Cause.Code(), + Info: sockErr.Cause.Info(), } switch sockErr.NetProto { diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index c9bcf9326..23aa0ad05 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "prependable.go", "view.go", + "view_unsafe.go", ], visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 91cc62cc8..b05e81526 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -239,6 +239,16 @@ func (vv *VectorisedView) Size() int { return vv.size } +// MemSize returns the estimation size of the vv in memory, including backing +// buffer data. +func (vv *VectorisedView) MemSize() int { + var size int + for _, v := range vv.views { + size += cap(v) + } + return size + cap(vv.views)*viewStructSize + vectorisedViewStructSize +} + // ToView returns a single view containing the content of the vectorised view. // // If the vectorised view contains a single view, that view will be returned diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index e7f7cc9f1..78b2faa26 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -20,6 +20,7 @@ import ( "io" "reflect" "testing" + "unsafe" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -578,3 +579,15 @@ func TestAppendView(t *testing.T) { } } } + +func TestMemSize(t *testing.T) { + const perViewCap = 128 + views := make([]buffer.View, 2, 32) + views[0] = make(buffer.View, 10, perViewCap) + views[1] = make(buffer.View, 20, perViewCap) + vv := buffer.NewVectorisedView(30, views) + want := int(unsafe.Sizeof(vv)) + cap(views)*int(unsafe.Sizeof(views)) + 2*perViewCap + if got := vv.MemSize(); got != want { + t.Errorf("vv.MemSize() = %d, want %d", got, want) + } +} diff --git a/pkg/tcpip/buffer/view_unsafe.go b/pkg/tcpip/buffer/view_unsafe.go new file mode 100644 index 000000000..75ccd40f8 --- /dev/null +++ b/pkg/tcpip/buffer/view_unsafe.go @@ -0,0 +1,22 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import "unsafe" + +const ( + vectorisedViewStructSize = int(unsafe.Sizeof(VectorisedView{})) + viewStructSize = int(unsafe.Sizeof(View{})) +) diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 5f9b8e9e2..f840a4322 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -16,7 +16,6 @@ package header import ( "encoding/binary" - "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -208,16 +207,3 @@ func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { return ^xsum } - -// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when -// a packet having a `net` header causing an ICMP error. -func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin { - switch net { - case IPv4ProtocolNumber: - return tcpip.SockExtErrorOriginICMP - case IPv6ProtocolNumber: - return tcpip.SockExtErrorOriginICMP6 - default: - panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net)) - } -} diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index c7ab876bf..933845269 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,7 +10,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7838cc753..0d7fadc31 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -22,7 +22,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,6 +34,8 @@ const ( ProtocolNumber = header.ARPProtocolNumber ) +var _ stack.LinkAddressResolver = (*endpoint)(nil) + // ARP endpoints need to implement stack.NetworkEndpoint because the stack // considers the layer above the link-layer a network layer; the only // facility provided by the stack to deliver packets to a layer above @@ -49,10 +50,8 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 - nic stack.NetworkInterface - linkAddrCache stack.LinkAddressCache - nud stack.NUDHandler - stats sharedStats + nic stack.NetworkInterface + stats sharedStats } func (e *endpoint) Enable() tcpip.Error { @@ -101,9 +100,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.protocol.forgetEndpoint(e.nic.ID()) -} +func (*endpoint) Close() {} func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} @@ -151,10 +148,12 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - if e.nud == nil { - e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) - } else { - e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + switch err := e.nic.HandleNeighborProbe(header.IPv4ProtocolNumber, remoteAddr, remoteLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ARP but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) } respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -195,14 +194,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - if e.nud == nil { - e.linkAddrCache.AddLinkAddress(addr, linkAddr) - return - } - // The solicited, override, and isRouter flags are not available for ARP; // they are only available for IPv6 Neighbor Advertisements. - e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ + switch err := e.nic.HandleNeighborConfirmation(header.IPv4ProtocolNumber, addr, linkAddr, stack.ReachabilityConfirmationFlags{ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies // are handled equivalently to a solicited Neighbor Advertisement. Solicited: true, @@ -211,7 +205,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { Override: false, // ARP does not distinguish between router and non-router hosts. IsRouter: false, - }) + }); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ARP but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err)) + } } } @@ -221,19 +221,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { } var _ stack.NetworkProtocol = (*protocol)(nil) -var _ stack.LinkAddressResolver = (*protocol)(nil) // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack - - mu struct { - sync.RWMutex - - // eps is keyed by NICID to allow protocol methods to retrieve the correct - // endpoint depending on the NIC. - eps map[tcpip.NICID]*endpoint - } } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -244,12 +235,10 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" } -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ - protocol: p, - nic: nic, - linkAddrCache: linkAddrCache, - nud: nud, + protocol: p, + nic: nic, } tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) @@ -257,43 +246,26 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L stackStats := p.stack.Stats() e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP) - p.mu.Lock() - p.mu.eps[nic.ID()] = e - p.mu.Unlock() - return e } -func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.mu.eps, nicID) -} - // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. -func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { - nicID := nic.ID() - - p.mu.Lock() - netEP, ok := p.mu.eps[nicID] - p.mu.Unlock() - if !ok { - return &tcpip.ErrNotConnected{} - } +func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + nicID := e.nic.ID() - stats := netEP.stats.arp + stats := e.stats.arp if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } if len(localAddr) == 0 { - addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) if !ok { return &tcpip.ErrUnknownNICID{} } @@ -304,13 +276,13 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } localAddr = addr.Address - } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { stats.outgoingRequestBadLocalAddressErrors.Increment() return &tcpip.ErrBadLocalAddress{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) pkt.NetworkProtocolNumber = ProtocolNumber @@ -318,14 +290,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot h.SetOp(header.ARPRequest) // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a // link address. - _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress()) if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } - if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stats.outgoingRequestsDropped.Increment() return err } @@ -334,7 +306,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. -func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return header.EthernetBroadcastAddress, true } @@ -369,9 +341,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return &protocol{ stack: s, - mu: struct { - sync.RWMutex - eps map[tcpip.NICID]*endpoint - }{eps: make(map[tcpip.NICID]*endpoint)}, } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index b0f07aa44..24357e15d 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -491,9 +491,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { t.Fatal(err) } - neighbors, err := c.s.Neighbors(nicID) + neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) if err != nil { - t.Fatalf("c.s.Neighbors(%d): %s", nicID, err) + t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) @@ -530,52 +530,19 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } -var _ stack.NetworkInterface = (*testInterface)(nil) +var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil) -type testInterface struct { +type testLinkEndpoint struct { stack.LinkEndpoint - nicID tcpip.NICID - writeErr tcpip.Error } -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (*testInterface) Enabled() bool { - return true -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) -} - -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) -} - -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if t.writeErr != nil { return t.writeErr } - var r stack.RouteInfo - r.NetProto = protocol - r.RemoteLinkAddress = remoteLinkAddr return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } @@ -709,32 +676,31 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, }) - p := s.NetworkProtocolInstance(arp.ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, linkEP); err != nil { + if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } + ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err) + } + linkRes, ok := ep.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) + } + if len(test.nicAddr) != 0 { if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) } } - // We pass a test network interface to LinkAddressRequest with the same - // NIC ID and link endpoint used by the NIC we created earlier so that we - // can mock a link address request and observe the packets sent to the - // link endpoint even though the stack uses the real NIC to validate the - // local address. - iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr} - err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) + { + err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) + } } if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { @@ -782,19 +748,3 @@ func TestLinkAddressRequest(t *testing.T) { }) } } - -func TestLinkAddressRequestWithoutNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - }) - p := s.NetworkProtocolInstance(arp.ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } - - err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}) - if _, ok := err.(*tcpip.ErrNotConnected); !ok { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{}) - } -} diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index 036fdf739..65c708ac4 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -34,55 +34,13 @@ func (t *testInterface) ID() tcpip.NICID { return t.nicID } -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEndpointAfterClose { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - func TestMultiCounterStatsInitialization(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // expected to be bound by a MultiCounterStat. refStack := s.Stats() diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 1af87d713..243738951 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -84,7 +84,7 @@ type Fragmentation struct { lowLimit int reassemblers map[FragmentID]*reassembler rList reassemblerList - size int + memSize int timeout time.Duration blockSize uint16 clock tcpip.Clock @@ -156,22 +156,22 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // the protocol to identify a fragment. func (f *Fragmentation) Process( id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( - buffer.VectorisedView, uint8, bool, error) { + *stack.PacketBuffer, uint8, bool, error) { if first > last { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) } if first%f.blockSize != 0 { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) } fragmentSize := last - first + 1 if more && fragmentSize%f.blockSize != 0 { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } if l := pkt.Data.Size(); l != int(fragmentSize) { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } f.mu.Lock() @@ -190,24 +190,24 @@ func (f *Fragmentation) Process( } f.mu.Unlock() - res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt) + resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() f.release(r, false /* timedOut */) f.mu.Unlock() - return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) + return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() - f.size += consumed + f.memSize += memConsumed if done { f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. - if f.size > f.highLimit { - for f.size > f.lowLimit { + if f.memSize > f.highLimit { + for f.memSize > f.lowLimit { tail := f.rList.Back() if tail == nil { break @@ -216,7 +216,7 @@ func (f *Fragmentation) Process( } } f.mu.Unlock() - return res, firstFragmentProto, done, nil + return resPkt, firstFragmentProto, done, nil } func (f *Fragmentation) release(r *reassembler, timedOut bool) { @@ -228,10 +228,10 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) { delete(f.reassemblers, r.id) f.rList.Remove(r) - f.size -= r.size - if f.size < 0 { - log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) - f.size = 0 + f.memSize -= r.memSize + if f.memSize < 0 { + log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize) + f.memSize = 0 } if h := f.timeoutHandler; timedOut && h != nil { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 3a79688a8..905bbc19b 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -16,7 +16,6 @@ package fragmentation import ( "errors" - "reflect" "testing" "time" @@ -112,20 +111,20 @@ func TestFragmentationProcess(t *testing.T) { f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) + resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.pkt, err) } - if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)", - in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView()) - } if done != c.out[i].done { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) } if c.out[i].done { + if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" { + t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) + } if firstFragmentProto != proto { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) @@ -173,9 +172,17 @@ func TestReassemblingTimeout(t *testing.T) { // reassembly is done after the fragment is processd. expectDone bool - // sizeAfterEvent is the expected size of the fragmentation instance after - // the event. - sizeAfterEvent int + // memSizeAfterEvent is the expected memory size of the fragmentation + // instance after the event. + memSizeAfterEvent int + } + + memSizeOfFrags := func(frags ...*fragment) int { + var size int + for _, frag := range frags { + size += pkt(len(frag.data), frag.data).MemSize() + } + return size } half1 := &fragment{first: 0, last: 0, more: true, data: "0"} @@ -189,16 +196,16 @@ func TestReassemblingTimeout(t *testing.T) { name: "half1 and half2 are reassembled successfully", events: []event{ { - name: "half1", - fragment: half1, - expectDone: false, - sizeAfterEvent: 1, + name: "half1", + fragment: half1, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half2", - fragment: half2, - expectDone: true, - sizeAfterEvent: 0, + name: "half2", + fragment: half2, + expectDone: true, + memSizeAfterEvent: 0, }, }, }, @@ -206,36 +213,36 @@ func TestReassemblingTimeout(t *testing.T) { name: "half1 timeout, half2 timeout", events: []event{ { - name: "half1", - fragment: half1, - expectDone: false, - sizeAfterEvent: 1, + name: "half1", + fragment: half1, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half1 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - sizeAfterEvent: 1, + name: "half1 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half1 reassembly timeout", - clockAdvance: 1, - sizeAfterEvent: 0, + name: "half1 reassembly timeout", + clockAdvance: 1, + memSizeAfterEvent: 0, }, { - name: "half2", - fragment: half2, - expectDone: false, - sizeAfterEvent: 1, + name: "half2", + fragment: half2, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half2), }, { - name: "half2 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - sizeAfterEvent: 1, + name: "half2 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + memSizeAfterEvent: memSizeOfFrags(half2), }, { - name: "half2 reassembly timeout", - clockAdvance: 1, - sizeAfterEvent: 0, + name: "half2 reassembly timeout", + clockAdvance: 1, + memSizeAfterEvent: 0, }, }, }, @@ -255,8 +262,8 @@ func TestReassemblingTimeout(t *testing.T) { t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) } } - if got, want := f.size, event.sizeAfterEvent; got != want { - t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want) + if got, want := f.memSize, event.memSizeAfterEvent; got != want { + t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) } } }) @@ -264,7 +271,9 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil) + lowLimit := pkt(1, "0").MemSize() + highLimit := 3 * lowLimit // Allow at most 3 such packets. + f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) // Send first fragment with id = 1. @@ -288,15 +297,14 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil) + memSize := pkt(1, "0").MemSize() + f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) // Send the same packet again. f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - got := f.size - want := 1 - if got != want { + if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) } } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9b20bb1d8..933d63d32 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -29,13 +28,15 @@ type hole struct { last uint16 filled bool final bool - data buffer.View + // pkt is the fragment packet if hole is filled. We keep the whole pkt rather + // than the fragmented payload to prevent binding to specific buffer types. + pkt *stack.PacketBuffer } type reassembler struct { reassemblerEntry id FragmentID - size int + memSize int proto uint8 mu sync.Mutex holes []hole @@ -59,18 +60,18 @@ func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { return r } -func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() if r.done { // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, 0, false, 0, nil + return nil, 0, false, 0, nil } var holeFound bool - var consumed int + var memConsumed int for i := range r.holes { currentHole := &r.holes[i] @@ -90,12 +91,12 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 if first < currentHole.first || currentHole.last < last { // Incoming fragment only partially fits in the free hole. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap + return nil, 0, false, 0, ErrFragmentOverlap } if !more { if !currentHole.final || currentHole.filled && currentHole.last != last { // We have another final fragment, which does not perfectly overlap. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + return nil, 0, false, 0, ErrFragmentConflict } } @@ -124,16 +125,15 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s }) currentHole.final = false } - v := pkt.Data.ToOwnedView() - consumed = v.Size() - r.size += consumed + memConsumed = pkt.MemSize() + r.memSize += memConsumed // Update the current hole to precisely match the incoming fragment. r.holes[i] = hole{ first: first, last: last, filled: true, final: currentHole.final, - data: v, + pkt: pkt, } r.filled++ // For IPv6, it is possible to have different Protocol values between @@ -153,25 +153,24 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s } if !holeFound { // Incoming fragment is beyond end. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + return nil, 0, false, 0, ErrFragmentConflict } // Check if all the holes have been filled and we are ready to reassemble. if r.filled < len(r.holes) { - return buffer.VectorisedView{}, 0, false, consumed, nil + return nil, 0, false, memConsumed, nil } sort.Slice(r.holes, func(i, j int) bool { return r.holes[i].first < r.holes[j].first }) - var size int - views := make([]buffer.View, 0, len(r.holes)) - for _, hole := range r.holes { - views = append(views, hole.data) - size += hole.data.Size() + resPkt := r.holes[0].pkt + for i := 1; i < len(r.holes); i++ { + fragPkt := r.holes[i].pkt + fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size()) } - return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil + return resPkt, r.proto, true, memConsumed, nil } func (r *reassembler) checkDoneOrMark() bool { diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index 2ff03eeeb..214a93709 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -15,6 +15,7 @@ package fragmentation import ( + "bytes" "math" "testing" @@ -44,16 +45,21 @@ func TestReassemblerProcess(t *testing.T) { return payload } - pkt := func(size int) *stack.PacketBuffer { + pkt := func(sizes ...int) *stack.PacketBuffer { + var vv buffer.VectorisedView + for _, size := range sizes { + vv.AppendView(v(size)) + } return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v(size).ToVectorisedView(), + Data: vv, }) } var tests = []struct { - name string - params []processParams - want []hole + name string + params []processParams + want []hole + wantPkt *stack.PacketBuffer }{ { name: "No fragments", @@ -64,7 +70,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment at beginning", params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, {first: 2, last: math.MaxUint16, filled: false, final: true}, }, }, @@ -72,7 +78,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment in the middle", params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true, final: false, data: v(2)}, + {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, {first: 0, last: 0, filled: false, final: false}, {first: 3, last: math.MaxUint16, filled: false, final: true}, }, @@ -81,7 +87,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment at the end", params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true, final: true, data: v(2)}, + {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, {first: 0, last: 0, filled: false}, }, }, @@ -89,8 +95,9 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment completing a packet", params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: true}, }, + wantPkt: pkt(2), }, { name: "Two fragments completing a packet", @@ -99,9 +106,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, - {first: 2, last: 3, filled: true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: false}, + {first: 2, last: 3, filled: true, final: true}, }, + wantPkt: pkt(2, 2), }, { name: "Two fragments completing a packet with a duplicate", @@ -111,9 +119,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, - {first: 2, last: 3, filled: true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: false}, + {first: 2, last: 3, filled: true, final: true}, }, + wantPkt: pkt(2, 2), }, { name: "Two fragments completing a packet with a partial duplicate", @@ -123,9 +132,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 3, filled: true, final: false, data: v(4)}, - {first: 4, last: 5, filled: true, final: true, data: v(2)}, + {first: 0, last: 3, filled: true, final: false}, + {first: 4, last: 5, filled: true, final: true}, }, + wantPkt: pkt(4, 2), }, { name: "Two overlapping fragments", @@ -134,7 +144,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, }, want: []hole{ - {first: 0, last: 10, filled: true, final: false, data: v(11)}, + {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, {first: 11, last: math.MaxUint16, filled: false, final: true}, }, }, @@ -145,7 +155,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, }, want: []hole{ - {first: 10, last: 14, filled: true, final: true, data: v(5)}, + {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, {first: 0, last: 9, filled: false, final: false}, }, }, @@ -156,7 +166,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, }, want: []hole{ - {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, {first: 0, last: 4, filled: false, final: false}, }, }, @@ -167,7 +177,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, }, want: []hole{ - {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, {first: 0, last: 4, filled: false, final: false}, }, }, @@ -176,14 +186,47 @@ func TestReassemblerProcess(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { r := newReassembler(FragmentID{}, &faketime.NullClock{}) + var resPkt *stack.PacketBuffer + var isDone bool for _, param := range test.params { - _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) if done != param.wantDone || err != param.wantError { t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) } + if done { + resPkt = pkt + isDone = true + } + } + + ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } + cmpPktData := func(a, b *stack.PacketBuffer) bool { + if a == nil || b == nil { + return a == b + } + return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView()) } - if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + + if isDone { + if diff := cmp.Diff( + test.want, r.holes, + cmp.AllowUnexported(hole{}), + // Do not compare pkt in hole. Data will be altered. + cmp.Comparer(ignorePkt), + ); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { + t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) + } + } else { + if diff := cmp.Diff( + test.want, r.holes, + cmp.AllowUnexported(hole{}), + cmp.Comparer(cmpPktData), + ); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } } }) } diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index a81f5c8c3..b9f129728 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -126,6 +126,16 @@ type multicastGroupState struct { // // Must not be nil. delayedReportJob *tcpip.Job + + // delyedReportJobFiresAt is the time when the delayed report job will fire. + // + // A zero value indicates that the job is not scheduled. + delayedReportJobFiresAt time.Time +} + +func (m *multicastGroupState) cancelDelayedReportJob() { + m.delayedReportJob.Cancel() + m.delayedReportJobFiresAt = time.Time{} } // GenericMulticastProtocolOptions holds options for the generic multicast @@ -428,7 +438,7 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { - info.delayedReportJob.Cancel() + info.cancelDelayedReportJob() info.lastToSendReport = false info.state = idleMember g.memberships[groupAddress] = info @@ -603,7 +613,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress return } - info.delayedReportJob.Cancel() + info.cancelDelayedReportJob() g.maybeSendLeave(groupAddress, info.lastToSendReport) info.lastToSendReport = false info.state = nonMember @@ -645,14 +655,24 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // If a timer for any address is already running, it is reset to the new // random value only if the requested Maximum Response Delay is less than // the remaining value of the running timer. + now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds()) if info.state == delayingMember { - // TODO: Reset the timer if time remaining is greater than maxResponseTime. - return + if info.delayedReportJobFiresAt.IsZero() { + panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress)) + } + + if info.delayedReportJobFiresAt.Sub(now) <= maxResponseTime { + // The timer is scheduled to fire before the maximum response time so we + // leave our timer as is. + return + } } info.state = delayingMember - info.delayedReportJob.Cancel() - info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) + info.cancelDelayedReportJob() + maxResponseTime = g.calculateDelayTimerDuration(maxResponseTime) + info.delayedReportJob.Schedule(maxResponseTime) + info.delayedReportJobFiresAt = now.Add(maxResponseTime) } // calculateDelayTimerDuration returns a random time between (0, maxRespTime]. diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index d5d5a449e..60eaea37e 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -408,40 +408,46 @@ func TestHandleReport(t *testing.T) { func TestHandleQuery(t *testing.T) { tests := []struct { - name string - queryAddr tcpip.Address - maxDelay time.Duration - expectReportsFor []tcpip.Address + name string + queryAddr tcpip.Address + maxDelay time.Duration + expectQueriedReportsFor []tcpip.Address + expectDelayedReportsFor []tcpip.Address }{ { - name: "Unpecified empty", - queryAddr: "", - maxDelay: 0, - expectReportsFor: []tcpip.Address{addr1, addr2}, + name: "Unpecified empty", + queryAddr: "", + maxDelay: 0, + expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, + expectDelayedReportsFor: nil, }, { - name: "Unpecified any", - queryAddr: "\x00", - maxDelay: 1, - expectReportsFor: []tcpip.Address{addr1, addr2}, + name: "Unpecified any", + queryAddr: "\x00", + maxDelay: 1, + expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, + expectDelayedReportsFor: nil, }, { - name: "Specified", - queryAddr: addr1, - maxDelay: 2, - expectReportsFor: []tcpip.Address{addr1}, + name: "Specified", + queryAddr: addr1, + maxDelay: 2, + expectQueriedReportsFor: []tcpip.Address{addr1}, + expectDelayedReportsFor: []tcpip.Address{addr2}, }, { - name: "Specified all-nodes", - queryAddr: addr3, - maxDelay: 3, - expectReportsFor: nil, + name: "Specified all-nodes", + queryAddr: addr3, + maxDelay: 3, + expectQueriedReportsFor: nil, + expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, }, { - name: "Specified other", - queryAddr: addr4, - maxDelay: 4, - expectReportsFor: nil, + name: "Specified other", + queryAddr: addr4, + maxDelay: 4, + expectQueriedReportsFor: nil, + expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, }, } @@ -469,20 +475,20 @@ func TestHandleQuery(t *testing.T) { if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Receiving a query should make us schedule a new delayed report if it - // is a query directed at us or a general query. + // Receiving a query should make us reschedule our delayed report timer + // to some time within the new max response delay. mgp.handleQuery(test.queryAddr, test.maxDelay) - if len(test.expectReportsFor) != 0 { - clock.Advance(test.maxDelay) - if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + clock.Advance(test.maxDelay) + if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The groups that were not affected by the query should still send a + // report after the max unsolicited report delay. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 47cce79bb..6a1f11a36 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -59,6 +59,14 @@ var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ PrefixLen: 120, } +type transportError struct { + origin tcpip.SockErrOrigin + typ uint8 + code uint8 + info uint32 + kind stack.TransportErrorKind +} + // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. // The former is used to pretend that it's a link endpoint so that we can // inspect packets written by the network endpoints. The latter is used to @@ -74,8 +82,7 @@ type testObject struct { srcAddr tcpip.Address dstAddr tcpip.Address v4 bool - typ stack.ControlType - extra uint32 + transErr transportError dataCalls int controlCalls int @@ -119,16 +126,23 @@ func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumb return stack.TransportPacketHandled } -// DeliverTransportControlPacket is called by network endpoints after parsing +// DeliverTransportError is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) - if typ != t.typ { - t.t.Errorf("typ = %v, want %v", typ, t.typ) - } - if extra != t.extra { - t.t.Errorf("extra = %v, want %v", extra, t.extra) + if diff := cmp.Diff( + t.transErr, + transportError{ + origin: transErr.Origin(), + typ: transErr.Type(), + code: transErr.Code(), + info: transErr.Info(), + kind: transErr.Kind(), + }, + cmp.AllowUnexported(transportError{}), + ); diff != "" { + t.t.Errorf("transport error mismatch (-want +got):\n%s", diff) } t.controlCalls++ } @@ -311,6 +325,14 @@ func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.N return &tcpip.ErrNotSupported{} } +func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { + return nil +} + +func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { + return nil +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -465,7 +487,7 @@ func TestEnableWhenNICDisabled(t *testing.T) { // We pass nil for all parameters except the NetworkInterface and Stack // since Enable only depends on these. - ep := p.NewEndpoint(&nic, nil, nil, nil) + ep := p.NewEndpoint(&nic, nil) // The endpoint should initially be disabled, regardless the NIC's enabled // status. @@ -527,7 +549,7 @@ func TestIPv4Send(t *testing.T) { v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, nil) + ep := proto.NewEndpoint(&nic, nil) defer ep.Close() // Allocate and initialize the payload view. @@ -661,7 +683,7 @@ func TestReceive(t *testing.T) { v4: test.v4, }, } - ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -694,24 +716,81 @@ func TestReceive(t *testing.T) { } func TestIPv4ReceiveControl(t *testing.T) { - const mtu = 0xbeef - header.IPv4MinimumSize + const ( + mtu = 0xbeef - header.IPv4MinimumSize + dataLen = 8 + ) + cases := []struct { name string expectedCount int fragmentOffset uint16 code header.ICMPv4Code - expectedTyp stack.ControlType - expectedExtra uint32 + transErr transportError trunc int }{ - {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, - {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, - {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, + { + name: "FragmentationNeeded", + expectedCount: 1, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4FragmentationNeeded), + info: mtu, + kind: stack.PacketTooBigTransportError, + }, + trunc: 0, + }, + { + name: "Truncated (missing IPv4 header)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize, + }, + { + name: "Truncated (partial offending packet's IP header)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1, + }, + { + name: "Truncated (partial offending packet's data)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1, + }, + { + name: "Port unreachable", + expectedCount: 1, + fragmentOffset: 0, + code: header.ICMPv4PortUnreachable, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4PortUnreachable), + kind: stack.DestinationPortUnreachableTransportError, + }, + trunc: 0, + }, + { + name: "Non-zero fragment offset", + expectedCount: 0, + fragmentOffset: 100, + code: header.ICMPv4PortUnreachable, + trunc: 0, + }, + { + name: "Zero-length packet", + expectedCount: 0, + fragmentOffset: 100, + code: header.ICMPv4PortUnreachable, + trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen, + }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -722,7 +801,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -730,7 +809,7 @@ func TestIPv4ReceiveControl(t *testing.T) { } const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize - view := buffer.NewView(dataOffset + 8) + view := buffer.NewView(dataOffset + dataLen) // Create the outer IPv4 header. ip := header.IPv4(view) @@ -777,8 +856,7 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.srcAddr = remoteIPv4Addr nic.testObject.dstAddr = localIPv4Addr nic.testObject.contents = view[dataOffset:] - nic.testObject.typ = c.expectedTyp - nic.testObject.extra = c.expectedExtra + nic.testObject.transErr = c.transErr addressableEndpoint, ok := ep.(stack.AddressableEndpoint) if !ok { @@ -811,7 +889,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -906,7 +984,7 @@ func TestIPv6Send(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, nil) + ep := proto.NewEndpoint(&nic, nil) defer ep.Close() if err := ep.Enable(); err != nil { @@ -945,30 +1023,112 @@ func TestIPv6Send(t *testing.T) { } func TestIPv6ReceiveControl(t *testing.T) { + const ( + mtu = 0xffff + outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" + dataLen = 8 + ) + newUint16 := func(v uint16) *uint16 { return &v } - const mtu = 0xffff - const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" + portUnreachableTransErr := transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6DstUnreachable), + code: uint8(header.ICMPv6PortUnreachable), + kind: stack.DestinationPortUnreachableTransportError, + } + cases := []struct { name string expectedCount int fragmentOffset *uint16 typ header.ICMPv6Type code header.ICMPv6Code - expectedTyp stack.ControlType - expectedExtra uint32 + transErr transportError trunc int }{ - {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, - {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, - {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, - {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, + { + name: "PacketTooBig", + expectedCount: 1, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6PacketTooBig), + code: uint8(header.ICMPv6UnusedCode), + info: mtu, + kind: stack.PacketTooBigTransportError, + }, + trunc: 0, + }, + { + name: "Truncated (missing offending packet's IPv6 header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize, + }, + { + name: "Truncated PacketTooBig (partial offending packet's IPv6 header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1, + }, + { + name: "Truncated (partial offending packet's data)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1, + }, + { + name: "Port unreachable", + expectedCount: 1, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "Truncated DstPortUnreachable (partial offending packet's IP header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1, + }, + { + name: "DstPortUnreachable for Fragmented, zero offset", + expectedCount: 1, + fragmentOffset: newUint16(0), + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "DstPortUnreachable for Non-zero fragment offset", + expectedCount: 0, + fragmentOffset: newUint16(100), + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "Zero-length packet", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen, + }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -979,7 +1139,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -990,7 +1150,7 @@ func TestIPv6ReceiveControl(t *testing.T) { if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize } - view := buffer.NewView(dataOffset + 8) + view := buffer.NewView(dataOffset + dataLen) // Create the outer IPv6 header. ip := header.IPv6(view) @@ -1041,8 +1201,7 @@ func TestIPv6ReceiveControl(t *testing.T) { nic.testObject.srcAddr = remoteIPv6Addr nic.testObject.dstAddr = localIPv6Addr nic.testObject.contents = view[dataOffset:] - nic.testObject.typ = c.expectedTyp - nic.testObject.extra = c.expectedExtra + nic.testObject.transErr = c.transErr // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 330a7d170..9713c4448 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -44,6 +44,7 @@ go_test( "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3d93a2cd0..74e70e283 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -23,11 +23,108 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// icmpv4DestinationUnreachableSockError is a general ICMPv4 Destination +// Unreachable error. +// +// +stateify savable +type icmpv4DestinationUnreachableSockError struct{} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP +} + +// Type implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Type() uint8 { + return uint8(header.ICMPv4DstUnreachable) +} + +// Info implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Info() uint32 { + return 0 +} + +var _ stack.TransportError = (*icmpv4DestinationHostUnreachableSockError)(nil) + +// icmpv4DestinationHostUnreachableSockError is an ICMPv4 Destination Host +// Unreachable error. +// +// It indicates that a packet was not able to reach the destination host. +// +// +stateify savable +type icmpv4DestinationHostUnreachableSockError struct { + icmpv4DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4DestinationHostUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv4HostUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv4DestinationHostUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationHostUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv4DestinationPortUnreachableSockError)(nil) + +// icmpv4DestinationPortUnreachableSockError is an ICMPv4 Destination Port +// Unreachable error. +// +// It indicates that a packet reached the destination host, but the transport +// protocol was not active on the destination port. +// +// +stateify savable +type icmpv4DestinationPortUnreachableSockError struct { + icmpv4DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4DestinationPortUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv4PortUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv4DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationPortUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv4FragmentationNeededSockError)(nil) + +// icmpv4FragmentationNeededSockError is an ICMPv4 Destination Unreachable error +// due to fragmentation being required but the packet was set to not be +// fragmented. +// +// It indicates that a link exists on the path to the destination with an MTU +// that is too small to carry the packet. +// +// +stateify savable +type icmpv4FragmentationNeededSockError struct { + icmpv4DestinationUnreachableSockError + + mtu uint32 +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4FragmentationNeededSockError) Code() uint8 { + return uint8(header.ICMPv4FragmentationNeeded) +} + +// Info implements tcpip.SockErrorCause. +func (e *icmpv4FragmentationNeededSockError) Info() uint32 { + return e.mtu +} + +// Kind implements stack.TransportError. +func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind { + return stack.PacketTooBigTransportError +} + // handleControl handles the case when an ICMP error packet contains the headers // of the original packet that caused the ICMP one to be sent. This information // is used to find out which transport endpoint must be notified about the ICMP // packet. We only expect the payload, not the enclosing ICMP packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -54,10 +151,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack return } - // Skip the ip header, then deliver control message. + // Skip the ip header, then deliver the error. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { @@ -222,19 +319,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4HostUnreachable: - e.handleControl(stack.ControlNoRoute, 0, pkt) - + e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, pkt) - + e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt) case header.ICMPv4FragmentationNeeded: networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) if err != nil { networkMTU = 0 } - e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) + e.handleControl(&icmpv4FragmentationNeededSockError{mtu: networkMTU}, pkt) } - case header.ICMPv4SrcQuench: received.srcQuench.Increment() diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 4cd0b3256..acc126c3b 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -215,6 +215,11 @@ func (igmp *igmpState) setV1Present(v bool) { } } +func (igmp *igmpState) resetV1Present() { + igmp.igmpV1Job.Cancel() + igmp.setV1Present(false) +} + // handleMembershipQuery handles a membership query. // // Precondition: igmp.ep.mu must be locked. diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 1ee573ac8..95fd75ab7 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -101,10 +101,10 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma }) } -// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is -// present on the network. The IGMP stack will then send IGMPv1 Membership -// reports for backwards compatibility. -func TestIgmpV1Present(t *testing.T) { +// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1 +// router is detected. V1 present status is expected to be reset when the NIC +// cycles. +func TestIGMPV1Present(t *testing.T) { e, s, clock := createStack(t, true) if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) @@ -116,14 +116,16 @@ func TestIgmpV1Present(t *testing.T) { // This NIC will send an IGMPv2 report immediately, before this test can get // the IGMPv1 General Membership Query in. - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + { + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) if t.Failed() { t.FailNow() } @@ -145,19 +147,38 @@ func TestIgmpV1Present(t *testing.T) { // Verify the solicited Membership Report is sent. Now that this NIC has seen // an IGMPv1 query, it should send an IGMPv1 Membership Report. - p, ok = e.Read() - if ok { + if p, ok := e.Read(); ok { t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) } clock.Advance(ipv4.UnsolicitedReportIntervalMax) - p, ok = e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + { + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) + } + + // Cycling the interface should reset the V1 present flag. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + { + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) } func TestSendQueuedIGMPReports(t *testing.T) { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e5c80699d..b2d626107 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -101,11 +101,11 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { // Use the same control type as an ICMPv4 destination host unreachable error // since the host is considered unreachable if we cannot resolve the link // address to the next hop. - e.handleControl(stack.ControlNoRoute, 0, pkt) + e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, dispatcher: dispatcher, @@ -229,6 +229,12 @@ func (e *endpoint) disableLocked() { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + // Reset the IGMP V1 present flag. + // + // If the node comes back up on the same network, it will re-learn that it + // needs to perform IGMPv1. + e.mu.igmp.resetV1Present() + if !e.setEnabled(false) { panic("should have only done work to disable the endpoint if it was enabled") } @@ -734,7 +740,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } proto := h.Protocol() - data, _, ready, err := e.protocol.fragmentation.Process( + resPkt, _, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -757,7 +763,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { if !ready { return } - pkt.Data = data + pkt = resPkt + h = header.IPv4(pkt.NetworkHeader().View()) // The reassembler doesn't take care of fixing up the header, so we need // to do it here. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ed5899f0b..a296bed79 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -38,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -2058,7 +2059,7 @@ func TestReceiveFragments(t *testing.T) { // the fragment block size of 8 (RFC 791 section 3.1 page 14). ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] - // Used to test the max reassembled payload length (65,535 octets). + // Used to test the max reassembled IPv4 payload length. ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] @@ -2406,6 +2407,7 @@ func TestReceiveFragments(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + RawFactory: raw.EndpointFactory{}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { @@ -2431,6 +2433,13 @@ func TestReceiveFragments(t *testing.T) { t.Fatalf("Bind(%+v): %s", bindAddr, err) } + // Bring up a raw endpoint so we can examine network headers. + epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */) + if err != nil { + t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) + } + defer epRaw.Close() + // Prepare and send the fragments. for _, frag := range test.fragments { hdr := buffer.NewPrependable(header.IPv4MinimumSize) @@ -2462,10 +2471,11 @@ func TestReceiveFragments(t *testing.T) { } for i, expectedPayload := range test.expectedPayloads { + // Check UDP payload delivered by UDP endpoint. var buf bytes.Buffer result, err := ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("(i=%d) Read: %s", i, err) + t.Fatalf("(i=%d) ep.Read: %s", i, err) } if diff := cmp.Diff(tcpip.ReadResult{ Count: len(expectedPayload), @@ -2474,7 +2484,24 @@ func TestReceiveFragments(t *testing.T) { t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) } if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) + t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff) + } + + // Check IPv4 header in packet delivered by raw endpoint. + buf.Reset() + result, err = epRaw.Read(&buf, tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("(i=%d) epRaw.Read: %s", i, err) + } + // Reassambly does not take care of checksum. Here we write our own + // check routine instead of using checker.IPv4. + ip := header.IPv4(buf.Bytes()) + for _, check := range []checker.NetworkChecker{ + checker.FragmentFlags(0), + checker.FragmentOffset(0), + checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))), + } { + check(t, []header.Network{ip}) } } diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index b28e7dcde..fbbc6e69c 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -50,7 +50,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) var nicIDs []tcpip.NICID proto.mu.Lock() @@ -82,7 +82,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // expected to be bound by a MultiCounterStat. refStack := s.Stats() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 7298bd061..dcfd93bab 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -23,11 +23,136 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// icmpv6DestinationUnreachableSockError is a general ICMPv6 Destination +// Unreachable error. +// +// +stateify savable +type icmpv6DestinationUnreachableSockError struct{} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv6DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP6 +} + +// Type implements tcpip.SockErrorCause. +func (*icmpv6DestinationUnreachableSockError) Type() uint8 { + return uint8(header.ICMPv6DstUnreachable) +} + +// Info implements tcpip.SockErrorCause. +func (*icmpv6DestinationUnreachableSockError) Info() uint32 { + return 0 +} + +var _ stack.TransportError = (*icmpv6DestinationNetworkUnreachableSockError)(nil) + +// icmpv6DestinationNetworkUnreachableSockError is an ICMPv6 Destination Network +// Unreachable error. +// +// It indicates that the destination network is unreachable. +// +// +stateify savable +type icmpv6DestinationNetworkUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationNetworkUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6NetworkUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv6DestinationNetworkUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationNetworkUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6DestinationPortUnreachableSockError)(nil) + +// icmpv6DestinationPortUnreachableSockError is an ICMPv6 Destination Port +// Unreachable error. +// +// It indicates that a packet reached the destination host, but the transport +// protocol was not active on the destination port. +// +// +stateify savable +type icmpv6DestinationPortUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationPortUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6PortUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv6DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationPortUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6DestinationAddressUnreachableSockError)(nil) + +// icmpv6DestinationAddressUnreachableSockError is an ICMPv6 Destination Address +// Unreachable error. +// +// It indicates that a packet was not able to reach the destination. +// +// +stateify savable +type icmpv6DestinationAddressUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationAddressUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6AddressUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv6DestinationAddressUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationHostUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6PacketTooBigSockError)(nil) + +// icmpv6PacketTooBigSockError is an ICMPv6 Packet Too Big error. +// +// It indicates that a link exists on the path to the destination with an MTU +// that is too small to carry the packet. +// +// +stateify savable +type icmpv6PacketTooBigSockError struct { + mtu uint32 +} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP6 +} + +// Type implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Type() uint8 { + return uint8(header.ICMPv6PacketTooBig) +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Code() uint8 { + return uint8(header.ICMPv6UnusedCode) +} + +// Info implements tcpip.SockErrorCause. +func (e *icmpv6PacketTooBigSockError) Info() uint32 { + return e.mtu +} + +// Kind implements stack.TransportError. +func (*icmpv6PacketTooBigSockError) Kind() stack.TransportErrorKind { + return stack.PacketTooBigTransportError +} + // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -67,8 +192,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack p = fragHdr.TransportProtocol() } - // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportError(src, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -175,7 +299,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { if err != nil { networkMTU = 0 } - e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) + e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() @@ -187,11 +311,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { case header.ICMPv6NetworkUnreachable: - e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationPortUnreachableSockError{}, pkt) } - case header.ICMPv6NeighborSolicit: received.neighborSolicit.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { @@ -289,10 +412,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } else if unspecifiedSource { received.invalid.Increment() return - } else if e.nud != nil { - e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr) + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } // As per RFC 4861 section 7.1.1: @@ -458,18 +585,17 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - if e.nud == nil { - if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr) - } - return - } - - e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + switch err := e.nic.HandleNeighborConfirmation(ProtocolNumber, targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ Solicited: na.SolicitedFlag(), Override: na.OverrideFlag(), IsRouter: na.RouterFlag(), - }) + }); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err)) + } case header.ICMPv6EchoRequest: received.echoRequest.Increment() @@ -575,10 +701,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } - if e.nud != nil { - // A RS with a specified source IP address modifies the NUD state - // machine in the same way a reachability probe would. - e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol) + // A RS with a specified source IP address modifies the neighbor table + // in the same way a regular probe would. + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) } } @@ -627,8 +757,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. - if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol) + if len(sourceLinkAddr) != 0 { + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, routerAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } e.mu.Lock() @@ -694,24 +830,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } } -var _ stack.LinkAddressResolver = (*protocol)(nil) - // LinkAddressProtocol implements stack.LinkAddressResolver. -func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { - nicID := nic.ID() - - p.mu.Lock() - netEP, ok := p.mu.eps[nicID] - p.mu.Unlock() - if !ok { - return &tcpip.ErrNotConnected{} - } - +func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { remoteAddr := targetAddr if len(remoteLinkAddr) == 0 { remoteAddr = header.SolicitedNodeAddr(targetAddr) @@ -719,22 +844,22 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { - addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) + addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) if addressEndpoint == nil { return &tcpip.ErrNetworkUnreachable{} } localAddr = addressEndpoint.AddressWithPrefix().Address - } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 { + } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 { return &tcpip.ErrBadLocalAddress{} } optsSerializer := header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), + header.NDPSourceLinkLayerAddressOption(e.nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) @@ -751,9 +876,9 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot panic(fmt.Sprintf("failed to add IP header: %s", err)) } - stat := netEP.stats.icmp.packetsSent + stat := e.stats.icmp.packetsSent - if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stat.dropped.Increment() return err } @@ -763,7 +888,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } // ResolveStaticAddress implements stack.LinkAddressResolver. -func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if header.IsV6MulticastAddress(addr) { return header.EthernetAddressFromMulticastIPv6Address(addr), true } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index db1c2e663..92f9ee2c2 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -93,35 +93,14 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st return stack.TransportPacketHandled } -var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil) - -type stubLinkAddressCache struct{} - -func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {} - -type stubNUDHandler struct { - probeCount int - confirmationCount int -} - -var _ stack.NUDHandler = (*stubNUDHandler)(nil) - -func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { - s.probeCount++ -} - -func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { - s.confirmationCount++ -} - -func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { -} - var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { stack.LinkEndpoint + probeCount int + confirmationCount int + nicID tcpip.NICID } @@ -160,6 +139,16 @@ func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gs return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } +func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { + t.probeCount++ + return nil +} + +func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { + t.confirmationCount++ + return nil +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -202,7 +191,7 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -360,7 +349,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -573,12 +562,19 @@ func newTestContext(t *testing.T) *testContext { }}, ) - return c -} + t.Cleanup(func() { + if err := c.s0.RemoveNIC(nicID); err != nil { + t.Errorf("c.s0.RemoveNIC(%d): %s", nicID, err) + } + if err := c.s1.RemoveNIC(nicID); err != nil { + t.Errorf("c.s1.RemoveNIC(%d): %s", nicID, err) + } -func (c *testContext) cleanup() { - c.linkEP0.Close() - c.linkEP1.Close() + c.linkEP0.Close() + c.linkEP1.Close() + }) + + return c } type routeArgs struct { @@ -628,7 +624,6 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. func TestLinkResolution(t *testing.T) { c := newTestContext(t) - defer c.cleanup() r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -1346,29 +1341,32 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - p := s.NetworkProtocolInstance(ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") - } linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) if err := s.CreateNIC(nicID, linkEP); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) + } + linkRes, ok := ep.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) + } + if len(test.nicAddr) != 0 { if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) } } - // We pass a test network interface to LinkAddressRequest with the same NIC - // ID and link endpoint used by the NIC we created earlier so that we can - // mock a link address request and observe the packets sent to the link - // endpoint even though the stack uses the real NIC. - err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) + { + err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) + } } if test.expectedErr != nil { @@ -1798,8 +1796,9 @@ func TestCallsToNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - nudHandler := &stubNUDHandler{} - ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + + testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)} + ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -1834,11 +1833,11 @@ func TestCallsToNeighborCache(t *testing.T) { ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. - if nudHandler.probeCount != test.wantProbeCount { - t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount) + if testInterface.probeCount != test.wantProbeCount { + t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount) } - if nudHandler.confirmationCount != test.wantConfirmationCount { - t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount) + if testInterface.confirmationCount != test.wantConfirmationCount { + t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount) } }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 94043ed4e..c2e8c3ea7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -164,6 +164,7 @@ func getLabel(addr tcpip.Address) uint8 { panic(fmt.Sprintf("should have a label for address = %s", addr)) } +var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) @@ -172,13 +173,11 @@ var _ stack.NDPEndpoint = (*endpoint)(nil) var _ NDPEndpoint = (*endpoint)(nil) type endpoint struct { - nic stack.NetworkInterface - linkAddrCache stack.LinkAddressCache - nud stack.NUDHandler - dispatcher stack.TransportDispatcher - protocol *protocol - stack *stack.Stack - stats sharedStats + nic stack.NetworkInterface + dispatcher stack.TransportDispatcher + protocol *protocol + stack *stack.Stack + stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. @@ -236,7 +235,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { }) pkt.NICID = e.nic.ID() pkt.NetworkProtocolNumber = ProtocolNumber - e.handleControl(stack.ControlAddressUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationAddressUnreachableSockError{}, pkt) } // onAddressAssignedLocked handles an address being assigned. @@ -1167,7 +1166,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. - data, proto, ready, err := e.protocol.fragmentation.Process( + resPkt, proto, ready, err := e.protocol.fragmentation.Process( // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ @@ -1188,7 +1187,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } if ready { - pkt.Data = data + pkt = resPkt // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC @@ -1732,13 +1731,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ - nic: nic, - linkAddrCache: linkAddrCache, - nud: nud, - dispatcher: dispatcher, - protocol: p, + nic: nic, + dispatcher: dispatcher, + protocol: p, } e.mu.Lock() e.mu.addressableEndpointState.Init(e) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 8248052a3..1c6c37c91 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -2599,7 +2599,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) var nicIDs []tcpip.NICID proto.mu.Lock() @@ -3075,7 +3075,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // supposed to be bound. refStack := s.Stats() diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 1d38b8b05..8edaa9508 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -63,7 +63,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) if err := ep.Enable(); err != nil { t.Fatalf("ep.Enable(): %s", err) } @@ -199,6 +199,7 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -337,18 +338,18 @@ func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *test Data: hdr.View().ToVectorisedView(), })) - neighbors, err := s.Neighbors(nicID) + neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) for _, n := range neighbors { if existing, ok := neighborByAddr[n.Addr]; ok { if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) } neighborByAddr[n.Addr] = n } @@ -758,8 +759,10 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseLinkAddrCache: true, }) e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -905,18 +908,18 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes Data: hdr.View().ToVectorisedView(), })) - neighbors, err := s.Neighbors(nicID) + neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) for _, n := range neighbors { if existing, ok := neighborByAddr[n.Addr]; ok { if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) } neighborByAddr[n.Addr] = n } @@ -1275,8 +1278,8 @@ func TestNeighborAdvertisementValidation(t *testing.T) { // There is no need to create an entry if none exists, since the // recipient has apparently not initiated any communication with the // target. - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } else if len(neighbors) != 0 { t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 019d6a63c..1e00144a5 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -473,6 +473,48 @@ func (origin SockErrOrigin) IsICMPErr() bool { return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6 } +// SockErrorCause is the cause of a socket error. +type SockErrorCause interface { + // Origin is the source of the error. + Origin() SockErrOrigin + + // Type is the origin specific type of error. + Type() uint8 + + // Code is the origin and type specific error code. + Code() uint8 + + // Info is any extra information about the error. + Info() uint32 +} + +// LocalSockError is a socket error that originated from the local host. +// +// +stateify savable +type LocalSockError struct { + info uint32 +} + +// Origin implements SockErrorCause. +func (*LocalSockError) Origin() SockErrOrigin { + return SockExtErrorOriginLocal +} + +// Type implements SockErrorCause. +func (*LocalSockError) Type() uint8 { + return 0 +} + +// Code implements SockErrorCause. +func (*LocalSockError) Code() uint8 { + return 0 +} + +// Info implements SockErrorCause. +func (l *LocalSockError) Info() uint32 { + return l.info +} + // SockError represents a queue entry in the per-socket error queue. // // +stateify savable @@ -481,14 +523,8 @@ type SockError struct { // Err is the error caused by the errant packet. Err Error - // ErrOrigin indicates the error origin. - ErrOrigin SockErrOrigin - // ErrType is the type in the ICMP header. - ErrType uint8 - // ErrCode is the code in the ICMP header. - ErrCode uint8 - // ErrInfo is additional info about the error. - ErrInfo uint32 + // Cause is the detailed cause of the error. + Cause SockErrorCause // Payload is the errant packet's payload. Payload []byte @@ -540,12 +576,11 @@ func (so *SocketOptions) QueueErr(err *SockError) { // QueueLocalErr queues a local error onto the local queue. func (so *SocketOptions) QueueLocalErr(err Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { so.QueueErr(&SockError{ - Err: err, - ErrOrigin: SockExtErrorOriginLocal, - ErrInfo: info, - Payload: payload, - Dst: dst, - NetProto: net, + Err: err, + Cause: &LocalSockError{info: info}, + Payload: payload, + Dst: dst, + NetProto: net, }) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bb30556cf..ee23c9b98 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -72,6 +72,7 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", + "packet_buffer_unsafe.go", "pending_packets.go", "rand.go", "registration.go", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 63a42a2ea..c24f56ece 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -41,6 +41,7 @@ const ( protocolNumberOffset = 2 ) +var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil) var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. @@ -153,7 +154,6 @@ type fwdTestNetworkEndpointStats struct{} // IsNetworkEndpointStats implements stack.NetworkEndpointStats. func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} -var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) // fwdTestNetworkProtocol is a network-layer protocol that implements Address @@ -161,10 +161,9 @@ var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) type fwdTestNetworkProtocol struct { stack *Stack - addrCache *linkAddrCache - neigh *neighborCache + neighborTable neighborTable addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) + onLinkAddressResolved func(neighborTable, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) mu struct { @@ -197,7 +196,7 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint { e := &fwdTestNetworkEndpoint{ nic: nic, proto: f, @@ -219,23 +218,23 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { - if f.onLinkAddressResolved != nil { - time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) +func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + if fn := f.proto.onLinkAddressResolved; fn != nil { + time.AfterFunc(f.proto.addrResolveDelay, func() { + fn(f.proto.neighborTable, addr, remoteLinkAddr) }) } return nil } -func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if f.onResolveStaticAddress != nil { - return f.onResolveStaticAddress(addr) +func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if fn := f.proto.onResolveStaticAddress; fn != nil { + return fn(addr) } return "", false } -func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } @@ -401,11 +400,9 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC if !ok { t.Fatal("NIC 2 does not exist") } - if useNeighborCache { - // Control the neighbor cache for NIC 2. - proto.neigh = nic.neigh - } else { - proto.addrCache = nic.linkAddrCache + + if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { + proto.neighborTable = l.neighborTable } // Route all packets to NIC 2. @@ -482,43 +479,35 @@ func TestForwardingWithFakeResolver(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any address will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any address will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -573,7 +562,7 @@ func TestForwardingWithNoResolver(t *testing.T) { func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { proto := &fwdTestNetworkProtocol{ addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + onLinkAddressResolved: func(neighborTable, tcpip.Address, tcpip.LinkAddress) { // Don't resolve the link address. }, } @@ -606,49 +595,38 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.AddLinkAddress(addr, "c") - } - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) } }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. @@ -693,43 +671,35 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { @@ -769,43 +739,35 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. @@ -864,38 +826,31 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 930b8f795..5b6b58b1d 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -24,8 +24,6 @@ import ( const linkAddrCacheSize = 512 // max cache entries -var _ LinkAddressCache = (*linkAddrCache)(nil) - // linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. // // The entries are stored in a ring buffer, oldest entry replaced first. @@ -34,6 +32,8 @@ var _ LinkAddressCache = (*linkAddrCache)(nil) type linkAddrCache struct { nic *NIC + linkRes LinkAddressResolver + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -140,7 +140,7 @@ func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { } // add adds a k -> v mapping to the cache. -func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { +func (c *linkAddrCache) add(k tcpip.Address, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is // relative to the time when information was learned, rather than when it // happened to be inserted into the cache. @@ -198,10 +198,10 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { return entry } -// get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { +// get reports any known link address for addr. +func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { c.mu.Lock() - entry := c.getOrCreateEntryLocked(k) + entry := c.getOrCreateEntryLocked(addr) entry.mu.Lock() defer entry.mu.Unlock() c.mu.Unlock() @@ -224,7 +224,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } if entry.mu.done == nil { entry.mu.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: @@ -232,11 +232,11 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } } -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic) + c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) select { case now := <-time.After(c.resolutionTimeout): @@ -280,13 +280,80 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt return true } -func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - c := &linkAddrCache{ +func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) { + *c = linkAddrCache{ nic: nic, + linkRes: linkRes, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } + + c.mu.Lock() c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) - return c + c.mu.Unlock() +} + +var _ neighborTable = (*linkAddrCache)(nil) + +func (*linkAddrCache) neighbors() ([]NeighborEntry, tcpip.Error) { + return nil, &tcpip.ErrNotSupported{} +} + +func (c *linkAddrCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + c.add(addr, linkAddr) +} + +func (*linkAddrCache) remove(addr tcpip.Address) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (*linkAddrCache) removeAll() tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + if len(linkAddr) != 0 { + // NUD allows probes without a link address but linkAddrCache + // is a simple neighbor table which does not implement NUD. + // + // As per RFC 4861 section 4.3, + // + // Source link-layer address + // The link-layer address for the sender. MUST NOT be + // included when the source IP address is the + // unspecified address. Otherwise, on link layers + // that have addresses this option MUST be included in + // multicast solicitations and SHOULD be included in + // unicast solicitations. + c.add(addr, linkAddr) + } +} + +func (c *linkAddrCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + if len(linkAddr) != 0 { + // NUD allows confirmations without a link address but linkAddrCache + // is a simple neighbor table which does not implement NUD. + // + // As per RFC 4861 section 4.4, + // + // Target link-layer address + // The link-layer address for the target, i.e., the + // sender of the advertisement. This option MUST be + // included on link layers that have addresses when + // responding to multicast solicitations. When + // responding to a unicast Neighbor Solicitation this + // option SHOULD be included. + c.add(addr, linkAddr) + } +} + +func (c *linkAddrCache) handleUpperLevelConfirmation(tcpip.Address) {} + +func (*linkAddrCache) nudConfig() (NUDConfigurations, tcpip.Error) { + return NUDConfigurations{}, &tcpip.ErrNotSupported{} +} + +func (*linkAddrCache) setNUDConfig(NUDConfigurations) tcpip.Error { + return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 466a5e8d9..9e7f331c9 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { @@ -60,7 +60,7 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { for _, ta := range testAddrs { if ta.addr == addr { - r.cache.AddLinkAddress(ta.addr, ta.linkAddr) + r.cache.add(ta.addr, ta.linkAddr) break } } @@ -77,10 +77,10 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { - got, ch, err := c.get(addr, linkRes, "", nil, nil) + got, ch, err := c.get(addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { return got, &tcpip.ErrTimeout{} @@ -100,27 +100,28 @@ func newEmptyNIC() *NIC { } func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] - c.AddLinkAddress(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, e.linkAddr) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) + t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { e := testAddrs[i] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) + t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. @@ -135,15 +136,16 @@ func TestCacheOverflow(t *testing.T) { } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes) var wg sync.WaitGroup for r := 0; r < 16; r++ { wg.Add(1) go func() { for _, e := range testAddrs { - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) } wg.Done() }() @@ -154,12 +156,12 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] @@ -171,38 +173,40 @@ func TestCacheConcurrent(t *testing.T) { } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes) e := testAddrs[0] - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - _, _, err := c.get(e.addr, linkRes, "", nil, nil) + _, _, err := c.get(e.addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err) + t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err) } } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) e := testAddrs[0] l2 := e.linkAddr + "2" - c.AddLinkAddress(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, e.linkAddr) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } - c.AddLinkAddress(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, l2) + got, _, err = c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != l2 { - t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2) } } @@ -213,34 +217,36 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes) for i, ta := range testAddrs { - got, err := getBlocking(c, ta.addr, linkRes) + got, err := getBlocking(&c, ta.addr) if err != nil { - t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) + t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes) var requestCount uint32 linkRes.onLinkAddressRequest = func() { @@ -249,20 +255,20 @@ func TestCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working... e := testAddrs[0] - got, err := getBlocking(c, e.addr, linkRes) + got, err := getBlocking(&c, e.addr) if err != nil { - t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) + t.Errorf("getBlocking(_, %s): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) e.addr += "2" - a, err := getBlocking(c, e.addr, linkRes) + a, err := getBlocking(&c, e.addr) if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -273,12 +279,13 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) - linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay} + c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes) e := testAddrs[0] - a, err := getBlocking(c, e.addr, linkRes) + a, err := getBlocking(&c, e.addr) if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 64383bc7c..0238605af 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2796,14 +2796,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NIC: nicID, }}) - if useNeighborCache { - if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) - } - } else { - if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) - } + if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } return ndpDisp, e, s } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 88a3ff776..7e3132058 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -42,11 +42,10 @@ type NeighborStats struct { // 2. Static entries are explicitly added by a user and have no expiration. // Their state is always Static. The amount of static entries stored in the // cache is unbounded. -// -// neighborCache implements NUDHandler. type neighborCache struct { - nic *NIC - state *NUDState + nic *NIC + state *NUDState + linkRes LinkAddressResolver // mu protects the fields below. mu sync.RWMutex @@ -62,8 +61,6 @@ type neighborCache struct { } } -var _ NUDHandler = (*neighborCache)(nil) - // getOrCreateEntry retrieves a cache entry associated with addr. The // returned entry is always refreshed in the cache (it is reachable via the // map, and its place is bumped in LRU). @@ -73,7 +70,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -89,7 +86,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) + entry := newNeighborEntry(n, remoteAddr, n.state) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -126,8 +123,8 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() defer entry.mu.Unlock() @@ -206,7 +203,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state) } // removeEntry removes a dynamic or static entry by address from the neighbor @@ -263,27 +260,45 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { n.state.SetConfig(config) } -// HandleProbe implements NUDHandler.HandleProbe by following the logic defined -// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled -// by the caller. -func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +var _ neighborTable = (*neighborCache)(nil) + +func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { + return n.entries(), nil +} + +func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + entry, ch, err := n.entry(addr, localAddr, onResolve) + return entry.LinkAddr, ch, err +} + +func (n *neighborCache) remove(addr tcpip.Address) tcpip.Error { + if !n.removeEntry(addr) { + return &tcpip.ErrBadAddress{} + } + + return nil +} + +func (n *neighborCache) removeAll() tcpip.Error { + n.clear() + return nil +} + +// handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. +// +// Validation of the probe is expected to be handled by the caller. +func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() } -// HandleConfirmation implements NUDHandler.HandleConfirmation by following the -// logic defined in RFC 4861 section 7.2.5. +// handleConfirmation handles a neighbor confirmation as defined by +// RFC 4861 section 7.2.5. // -// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other -// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol -// should be deployed where preventing access to the broadcast segment might -// not be possible. SEND uses RSA key pairs to produce cryptographically -// generated addresses, as defined in RFC 3972, Cryptographically Generated -// Addresses (CGA). This ensures that the claimed source of an NDP message is -// the owner of the claimed address. -func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { +// Validation of the confirmation is expected to be handled by the caller. +func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() @@ -309,3 +324,12 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { entry.mu.Unlock() } } + +func (n *neighborCache) nudConfig() (NUDConfigurations, tcpip.Error) { + return n.config(), nil +} + +func (n *neighborCache) setNUDConfig(c NUDConfigurations) tcpip.Error { + n.setConfig(c) + return nil +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 2870e4f66..b489b5e08 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -76,10 +76,15 @@ func entryDiffOptsWithSort() []cmp.Option { })) } -func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { +func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - neigh := &neighborCache{ + linkRes := &testNeighborResolver{ + clock: clock, + entries: newTestEntryStore(), + delay: typicalLatency, + } + linkRes.neigh = &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, @@ -88,11 +93,11 @@ func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock id: 1, stats: makeNICStats(), }, - state: NewNUDState(config, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + state: NewNUDState(config, rng), + linkRes: linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } - neigh.nic.neigh = neigh - return neigh + return linkRes } // testEntryStore contains a set of IP to NeighborEntry mappings. @@ -194,7 +199,7 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { if !r.dropReplies { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { @@ -212,7 +217,7 @@ func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ // fakeRequest emulates handling a response for a link address request. func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { if entry, ok := r.entries.entryByAddr(addr); ok { - r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + r.neigh.handleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, @@ -242,10 +247,10 @@ func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. @@ -260,14 +265,14 @@ func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) c.MinRandomFactor = 1 c.MaxRandomFactor = 1 - neigh.setConfig(c) + linkRes.neigh.setConfig(c) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. @@ -282,22 +287,15 @@ func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -329,8 +327,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -346,23 +344,16 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -394,7 +385,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - neigh.removeEntry(entry.Addr) + linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -417,17 +408,15 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } { - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } } } type testContext struct { clock *faketime.ManualClock - neigh *neighborCache - store *testEntryStore linkRes *testNeighborResolver nudDisp *testNUDDispatcher } @@ -435,19 +424,10 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(nudDisp, c, clock) return testContext{ clock: clock, - neigh: neigh, - store: store, linkRes: linkRes, nudDisp: nudDisp, } @@ -461,17 +441,17 @@ type overflowOptions struct { func (c *testContext) overflowCache(opts overflowOptions) error { // Fill the neighbor cache to capacity to verify the LRU eviction strategy is // working properly after the entry removal. - for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { // Add a new entry - entry, ok := c.store.entry(i) + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -479,9 +459,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // LRU eviction strategy. Note that the number of static entries should not // affect the total number of dynamic entries that can be added. if i >= neighborCacheSize+opts.startAtEntryIndex { - removedEntry, ok := c.store.entry(i - neighborCacheSize) + removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) if !ok { - return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, @@ -524,10 +504,10 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries - for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { - entry, ok := c.store.entry(i) + for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -537,7 +517,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -581,15 +561,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -618,7 +598,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { } // Remove the entry - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -657,12 +637,12 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -683,7 +663,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { } // Remove the static entry that was just added - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) // No more events should have been dispatched. c.nudDisp.mu.Lock() @@ -701,12 +681,12 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -728,7 +708,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a duplicate entry with a different link address staticLinkAddr += "duplicate" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -763,12 +743,12 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -789,7 +769,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { } // Remove the static entry that was just added - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ { @@ -833,13 +813,13 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -871,7 +851,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Override the entry with a static one using the same address staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -926,14 +906,14 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -941,7 +921,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -983,23 +963,16 @@ func TestNeighborCacheClear(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add a dynamic entry. - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -1031,7 +1004,7 @@ func TestNeighborCacheClear(t *testing.T) { } // Add a static entry. - neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) { wantEvents := []testEntryEventInfo{ @@ -1055,7 +1028,7 @@ func TestNeighborCacheClear(t *testing.T) { } // Clear should remove both dynamic and static entries. - neigh.clear() + linkRes.neigh.clear() // Remove events dispatched from clear() have no deterministic order so they // need to be sorted beforehand. @@ -1099,13 +1072,13 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1136,7 +1109,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { } // Clear the cache. - c.neigh.clear() + c.linkRes.neigh.clear() { wantEvents := []testEntryEventInfo{ { @@ -1175,18 +1148,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - frequentlyUsedEntry, ok := store.entry(0) + frequentlyUsedEntry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1194,23 +1160,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Fill the neighbor cache to capacity for i := 0; i < neighborCacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1241,38 +1207,38 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < store.size(); i++ { + for i := neighborCacheSize; i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) + if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err) } } - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy - removedEntry, ok := store.entry(i - neighborCacheSize + 1) + removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) if !ok { - t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1) } wantEvents := []testEntryEventInfo{ { @@ -1322,10 +1288,10 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } - for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1335,7 +1301,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -1354,26 +1320,19 @@ func TestNeighborCacheConcurrent(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - storeEntries := store.entries() + storeEntries := linkRes.entries.entries() for _, entry := range storeEntries { var wg sync.WaitGroup for r := 0; r < concurrentProcesses; r++ { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) { + switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) { case nil, *tcpip.ErrWouldBlock: default: - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) } }(entry) } @@ -1391,10 +1350,10 @@ func TestNeighborCacheConcurrent(t *testing.T) { // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. var wantUnsortedEntries []NeighborEntry - for i := store.size() - neighborCacheSize; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Errorf("store.entry(%d) not found", i) + t.Errorf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1404,7 +1363,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1414,41 +1373,34 @@ func TestNeighborCacheReplace(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add an entry - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1459,21 +1411,21 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } // Notify of a link address change var updatedLinkAddr tcpip.LinkAddress { - entry, ok := store.entry(1) + entry, ok := linkRes.entries.entry(1) if !ok { - t.Fatal("store.entry(1) not found") + t.Fatal("linkRes.entries.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } - store.set(0, updatedLinkAddr) - neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + linkRes.entries.set(0, updatedLinkAddr) + linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, @@ -1483,9 +1435,9 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1493,17 +1445,17 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1511,7 +1463,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } } @@ -1521,46 +1473,39 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() + linkRes := newTestNeighborResolver(&nudDisp, config, clock) var requestCount uint32 - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - onLinkAddressRequest: func() { - atomic.AddUint32(&requestCount, 1) - }, + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) } - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1568,7 +1513,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } // Verify address resolution fails for an unknown address. @@ -1576,24 +1521,24 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - maxAttempts := neigh.config().MaxUnicastProbes + maxAttempts := linkRes.neigh.config().MaxUnicastProbes if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { t.Errorf("got link address request count = %d, want = %d", got, want) } @@ -1607,27 +1552,22 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config.RetransmitTimer = time.Millisecond // small enough to cause timeout clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: time.Minute, // large enough to cause timeout - } + linkRes := newTestNeighborResolver(nil, config, clock) + // large enough to cause timeout + linkRes.delay = time.Minute - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1635,7 +1575,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1644,31 +1584,24 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { func TestNeighborCacheRetryResolution(t *testing.T) { config := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - // Simulate a faulty link. - dropReplies: true, - } + linkRes := newTestNeighborResolver(nil, config, clock) + // Simulate a faulty link. + linkRes.dropReplies = true - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1676,7 +1609,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1688,20 +1621,20 @@ func TestNeighborCacheRetryResolution(t *testing.T) { State: Failed, }, } - if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) } // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) @@ -1713,9 +1646,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { if !ok { t.Fatal("expected successful address resolution") } - reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) } if reachableEntry.Addr != entry.Addr { t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) @@ -1727,7 +1660,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) } default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } } @@ -1736,42 +1669,36 @@ func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() clock := &tcpip.StdClock{} - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: 0, - } + linkRes := newTestNeighborResolver(nil, config, clock) + linkRes.delay = 0 // Clear for every possible size of the cache for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. for i := 0; i < cacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - b.Fatalf("store.entry(%d) not found", i) + b.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { case <-ch: default: - b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } b.StartTimer() - neigh.clear() + linkRes.neigh.clear() b.StopTimer() } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 53ac9bb6e..b05f96d4f 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -77,11 +77,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - - // linkRes provides the functionality to send reachability probes, used in - // Neighbor Unreachability Detection. - linkRes LinkAddressResolver + cache *neighborCache // nudState points to the Neighbor Unreachability Detection configuration. nudState *NUDState @@ -106,10 +102,9 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry { return &neighborEntry{ - nic: nic, - linkRes: linkRes, + cache: cache, nudState: nudState, neigh: NeighborEntry{ Addr: remoteAddr, @@ -121,18 +116,18 @@ func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, li // newStaticNeighborEntry creates a neighbor cache entry starting at the // Static state. The entry can only transition out of Static by directly // calling `setStateLocked`. -func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { +func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ Addr: addr, LinkAddr: linkAddr, State: Static, - UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), } - if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) + if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(cache.nic.id, entry) } return &neighborEntry{ - nic: nic, + cache: cache, nudState: state, neigh: entry, } @@ -158,7 +153,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // is resolved (which ends up obtaining the entry's lock) while holding the // link resolution queue's lock. Dequeuing packets in a new goroutine avoids // a lock ordering violation. - go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) + go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } @@ -167,8 +162,8 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh) } } @@ -177,8 +172,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh) } } @@ -187,8 +182,8 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh) } } @@ -206,7 +201,7 @@ func (e *neighborEntry) cancelJobLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchRemoveEventLocked() e.cancelJobLocked() e.notifyCompletionLocked(false /* succeeded */) @@ -222,7 +217,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { @@ -230,14 +225,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Stale) e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Probe) e.dispatchChangeEventLocked() }) @@ -254,14 +249,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -269,7 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(immediateDuration) case Failed: @@ -292,12 +287,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Failed: - e.nic.stats.Neighbor.FailedEntryLookups.Increment() + e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment() fallthrough case Unknown: e.neigh.State = Incomplete - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchAddEventLocked() @@ -340,7 +335,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // address SHOULD be placed in the IP Source Address of the outgoing // solicitation. // - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -350,7 +345,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -358,7 +353,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(immediateDuration) case Stale: @@ -504,7 +499,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 // here. - ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + ep, ok := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber] if !ok { panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 140b8ca00..57cfbdb8b 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { p := entryTestProbeInfo{ RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, @@ -230,22 +230,30 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e }, stats: makeNICStats(), } + netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), + header.IPv6ProtocolNumber: netEP, } rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) - linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) - + var linkRes entryTestLinkResolver // Stub out the neighbor cache to verify deletion from the cache. - nic.neigh = &neighborCache{ - nic: &nic, - state: nudState, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + neigh := &neighborCache{ + nic: &nic, + state: nudState, + linkRes: &linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + l := linkResolver{ + resolver: &linkRes, + neighborTable: neigh, + } + entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState) + neigh.cache[entryTestAddr1] = entry + nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{ + header.IPv6ProtocolNumber: l, } - nic.neigh.cache[entryTestAddr1] = entry return entry, &disp, &linkRes, clock } @@ -835,7 +843,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index e56a624fe..41a489047 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "math/rand" "reflect" "sync/atomic" @@ -25,8 +24,37 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +type neighborTable interface { + neighbors() ([]NeighborEntry, tcpip.Error) + addStaticEntry(tcpip.Address, tcpip.LinkAddress) + get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) + remove(tcpip.Address) tcpip.Error + removeAll() tcpip.Error + + handleProbe(tcpip.Address, tcpip.LinkAddress) + handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) + handleUpperLevelConfirmation(tcpip.Address) + + nudConfig() (NUDConfigurations, tcpip.Error) + setNUDConfig(NUDConfigurations) tcpip.Error +} + var _ NetworkInterface = (*NIC)(nil) +type linkResolver struct { + resolver LinkAddressResolver + + neighborTable neighborTable +} + +func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + return l.neighborTable.get(addr, localAddr, onResolve) +} + +func (l *linkResolver) confirmReachable(addr tcpip.Address) { + l.neighborTable.handleUpperLevelConfirmation(addr) +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -38,11 +66,11 @@ type NIC struct { context NICContext stats NICStats - neigh *neighborCache // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. - networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -53,8 +81,6 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution - linkAddrCache *linkAddrCache - mu struct { sync.RWMutex spoofing bool @@ -133,35 +159,18 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic := &NIC{ LinkEndpoint: ep, - stack: stack, - id: id, - name: name, - context: ctx, - stats: makeNICStats(), - networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + stack: stack, + id: id, + name: name, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver), } nic.linkResQueue.init(nic) - nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) - // Check for Neighbor Unreachability Detection support. - var nud NUDHandler - if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache { - rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) - nic.neigh = &neighborCache{ - nic: nic, - state: NewNUDState(stack.nudConfigs, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - - // An interface value that holds a nil pointer but non-nil type is not the - // same as the nil interface. Because of this, nud must only be assignd if - // nic.neigh is non-nil since a nil reference to a neighborCache is not - // valid. - // - // See https://golang.org/doc/faq#nil_error for more information. - nud = nic.neigh - } + resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { @@ -170,7 +179,32 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) + + netEP := netProto.NewEndpoint(nic, nic) + nic.networkEndpoints[netNum] = netEP + + if resolutionRequired { + if r, ok := netEP.(LinkAddressResolver); ok { + l := linkResolver{ + resolver: r, + } + + if stack.useNeighborCache { + l.neighborTable = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, stack.randomGenerator), + linkRes: r, + + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } else { + cache := new(linkAddrCache) + cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r) + l.neighborTable = cache + } + nic.linkAddrResolvers[r.LinkAddressProtocol()] = l + } + } } nic.LinkEndpoint.Attach(nic) @@ -223,18 +257,19 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() - } - // Clear the neighbour table (including static entries) as we cannot guarantee - // that the current neighbour table will be valid when the NIC is enabled - // again. - // - // This matches linux's behaviour at the time of writing: - // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 - switch err := n.clearNeighbors(); err.(type) { - case nil, *tcpip.ErrNotSupported: - default: - panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) + // Clear the neighbour table (including static entries) as we cannot + // guarantee that the current neighbour table will be valid when the NIC is + // enabled again. + // + // This matches linux's behaviour at the time of writing: + // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 + netProto := ep.NetworkProtocolNumber() + switch err := n.clearNeighbors(netProto); err.(type) { + case nil, *tcpip.ErrNotSupported: + default: + panic(fmt.Sprintf("n.clearNeighbors(%d): %s", netProto, err)) + } } if !n.setEnabled(false) { @@ -587,56 +622,52 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } -func (n *NIC) confirmReachable(addr tcpip.Address) { - if n := n.neigh; n != nil { - n.handleUpperLevelConfirmation(addr) +func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { + linkRes, ok := n.linkAddrResolvers[protocol] + if !ok { + return &tcpip.ErrNotSupported{} } -} -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - if n.neigh != nil { - entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) - return entry.LinkAddr, ch, err + if linkAddr, ok := linkRes.resolver.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil } - return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve) + _, _, err := linkRes.getNeighborLinkAddress(addr, localAddr, onResolve) + return err } -func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { - if n.neigh == nil { - return nil, &tcpip.ErrNotSupported{} +func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.neighbors() } - return n.neigh.entries(), nil + return nil, &tcpip.ErrNotSupported{} } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} +func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + linkRes.neighborTable.addStaticEntry(addr, linkAddress) + return nil } - n.neigh.addStaticEntry(addr, linkAddress) - return nil + return &tcpip.ErrNotSupported{} } -func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} +func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.remove(addr) } - if !n.neigh.removeEntry(addr) { - return &tcpip.ErrBadAddress{} - } - return nil + return &tcpip.ErrNotSupported{} } -func (n *NIC) clearNeighbors() tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} +func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.removeAll() } - n.neigh.clear() - return nil + return &tcpip.ErrNotSupported{} } // joinGroup adds a new endpoint for the given multicast address, if none @@ -880,9 +911,8 @@ func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt } } -// DeliverTransportControlPacket delivers control packets to the appropriate -// transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { +// DeliverTransportError implements TransportDispatcher. +func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -904,7 +934,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { + if n.stack.demux.deliverError(n, net, trans, transErr, pkt, id) { return } } @@ -920,24 +950,25 @@ func (n *NIC) Name() string { } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { - if n.neigh == nil { - return NUDConfigurations{}, &tcpip.ErrNotSupported{} +func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.nudConfig() } - return n.neigh.config(), nil + + return NUDConfigurations{}, &tcpip.ErrNotSupported{} } // setNUDConfigs sets the NUD configurations for n. // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} +func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + c.resetInvalidFields() + return linkRes.neighborTable.setNUDConfig(c) } - c.resetInvalidFields() - n.neigh.setConfig(c) - return nil + + return &tcpip.ErrNotSupported{} } func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { @@ -973,3 +1004,23 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RUnlock() return n.Enabled() && ep.IsAssigned(spoofing) } + +// HandleNeighborProbe implements NetworkInterface. +func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleProbe(addr, linkAddr) + return nil + } + + return &tcpip.ErrNotSupported{} +} + +// HandleNeighborConfirmation implements NetworkInterface. +func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleConfirmation(addr, linkAddr, flags) + return nil + } + + return &tcpip.ErrNotSupported{} +} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 2f719fbe5..9992d6eb4 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -111,8 +111,6 @@ type testIPv6EndpointStats struct{} // IsNetworkEndpointStats implements stack.NetworkEndpointStats. func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} -var _ LinkAddressResolver = (*testIPv6Protocol)(nil) - // We use this instead of ipv6.protocol because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Protocol struct{} @@ -139,7 +137,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ nic: nic, protocol: p, @@ -169,24 +167,6 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } -// LinkAddressProtocol implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -// LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { - return nil -} - -// ResolveStaticAddress implements LinkAddressResolver. -func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if header.IsV6MulticastAddress(addr) { - return header.EthernetAddressFromMulticastIPv6Address(addr), true - } - return "", false -} - func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 77926e289..5a94e9ac6 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -161,21 +161,6 @@ type ReachabilityConfirmationFlags struct { IsRouter bool } -// NUDHandler communicates external events to the Neighbor Unreachability -// Detection state machine, which is implemented per-interface. This is used by -// network endpoints to inform the Neighbor Cache of probes and confirmations. -type NUDHandler interface { - // HandleProbe processes an incoming neighbor probe (e.g. ARP request or - // Neighbor Solicitation for ARP or NDP, respectively). Validation of the - // probe needs to be performed before calling this function since the - // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) - - // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP - // reply or Neighbor Advertisement for ARP or NDP, respectively). - HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) -} - // NUDConfigurations is the NUD configurations for the netstack. This is used // by the neighbor cache to operate the NUD state machine on each device in the // local network. diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index ebfd5eb45..e9acef6a2 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -19,7 +19,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -52,83 +54,146 @@ func (f *fakeRand) Float32() float32 { return f.num } -// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NUD configurations using an invalid NICID. -func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The networking - // stack will only allocate neighbor caches if a protocol providing link - // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - UseNeighborCache: true, - }) +func TestNUDFunctions(t *testing.T) { + const nicID = 1 - // No NIC with ID 1 yet. - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(1, config) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{}) + tests := []struct { + name string + nicID tcpip.NICID + netProtoFactory []stack.NetworkProtocolFactory + extraLinkCapabilities stack.LinkEndpointCapabilities + expectedErr tcpip.Error + }{ + { + name: "Invalid NICID", + nicID: nicID + 1, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrUnknownNICID{}, + }, + { + name: "No network protocol", + nicID: nicID, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With resolution capability", + nicID: nicID, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6 and resolution capability", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + }, } -} -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to retrieve NUD configurations when the -// stack doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// a protocol providing link address resolution is specified (e.g. ARP, IPv6). -func TestNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, + NetworkProtocols: test.netProtoFactory, + Clock: clock, + }) - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channel.New(0, 0, linkAddr1) + e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired + e.LinkEPCapabilities |= test.extraLinkCapabilities - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - _, err := s.NUDConfigurations(nicID) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) - } -} + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to set NUD configurations when the stack -// doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// a protocol providing link address resolution is specified (e.g. ARP, IPv6). -func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 + configs := stack.DefaultNUDConfigurations() + configs.BaseReachableTime = time.Hour - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + { + err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + { + gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff(configs, gotConfigs); diff != "" { + t.Errorf("got configs mismatch (-want +got):\n%s", diff) + } + } + } - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(nicID, config) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) + for _, addr := range []tcpip.Address{llAddr1, llAddr2} { + { + err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff) + } + } + } + + { + wantErr := test.expectedErr + for i := 0; i < 2; i++ { + { + err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1) + if diff := cmp.Diff(wantErr, err); diff != "" { + t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } + + if test.expectedErr != nil { + break + } + + // Removing a neighbor that does not exist should give us a bad address + // error. + wantErr = &tcpip.ErrBadAddress{} + } + } + + { + neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Errorf("neighbors mismatch (-want +got):\n%s", diff) + } + } + } + + { + err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil { + t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err) + } else if len(neighbors) != 0 { + t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + } + } + }) } } -// TestDefaultNUDConfigurationIsValid verifies that calling -// resetInvalidFields() on the result of DefaultNUDConfigurations() does not -// change anything. DefaultNUDConfigurations() should return a valid -// NUDConfigurations. func TestDefaultNUDConfigurations(t *testing.T) { const nicID = 1 @@ -146,12 +211,12 @@ func TestDefaultNUDConfigurations(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - c, err := s.NUDConfigurations(nicID) + c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got, want := c, stack.DefaultNUDConfigurations(); got != want { - t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want) } } @@ -201,9 +266,9 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.BaseReachableTime; got != test.want { t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) @@ -258,9 +323,9 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MinRandomFactor; got != test.want { t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) @@ -338,9 +403,9 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxRandomFactor; got != test.want { t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) @@ -400,9 +465,9 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.RetransmitTimer; got != test.want { t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) @@ -452,9 +517,9 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.DelayFirstProbeTime; got != test.want { t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) @@ -504,9 +569,9 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxMulticastProbes; got != test.want { t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) @@ -556,9 +621,9 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxUnicastProbes; got != test.want { t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9d4fc3e48..4f013b212 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -187,6 +187,12 @@ func (pk *PacketBuffer) Size() int { return pk.HeaderSize() + pk.Data.Size() } +// MemSize returns the estimation size of the pk in memory, including backing +// buffer data. +func (pk *PacketBuffer) MemSize() int { + return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize +} + // Views returns the underlying storage of the whole packet. func (pk *PacketBuffer) Views() []buffer.View { // Optimization for outbound packets that headers are in pk.header. diff --git a/pkg/tcpip/stack/packet_buffer_unsafe.go b/pkg/tcpip/stack/packet_buffer_unsafe.go new file mode 100644 index 000000000..ee3d47270 --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_unsafe.go @@ -0,0 +1,19 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "unsafe" + +const packetBufferStructSize = int(unsafe.Sizeof(PacketBuffer{})) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 510da8689..d589f798d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -49,31 +49,6 @@ type TransportEndpointID struct { RemoteAddress tcpip.Address } -// ControlType is the type of network control message. -type ControlType int - -// The following are the allowed values for ControlType values. -// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. -const ( - // ControlAddressUnreachable indicates that an IPv6 packet did not reach its - // destination as the destination address was unreachable. - // - // This maps to the ICMPv6 Destination Ureachable Code 3 error; see - // RFC 4443 section 3.1 for more details. - ControlAddressUnreachable ControlType = iota - ControlNetworkUnreachable - // ControlNoRoute indicates that an IPv4 packet did not reach its destination - // because the destination host was unreachable. - // - // This maps to the ICMPv4 Destination Ureachable Code 1 error; see - // RFC 791's Destination Unreachable Message section (page 4) for more - // details. - ControlNoRoute - ControlPacketTooBig - ControlPortUnreachable - ControlUnknown -) - // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { // LocalAddressBroadcast is true if the packet's local address is a broadcast @@ -81,6 +56,39 @@ type NetworkPacketInfo struct { LocalAddressBroadcast bool } +// TransportErrorKind enumerates error types that are handled by the transport +// layer. +type TransportErrorKind int + +const ( + // PacketTooBigTransportError indicates that a packet did not reach its + // destination because a link on the path to the destination had an MTU that + // was too small to carry the packet. + PacketTooBigTransportError TransportErrorKind = iota + + // DestinationHostUnreachableTransportError indicates that the destination + // host was unreachable. + DestinationHostUnreachableTransportError + + // DestinationPortUnreachableTransportError indicates that a packet reached + // the destination host, but the transport protocol was not active on the + // destination port. + DestinationPortUnreachableTransportError + + // DestinationNetworkUnreachableTransportError indicates that the destination + // network was unreachable. + DestinationNetworkUnreachableTransportError +) + +// TransportError is a marker interface for errors that may be handled by the +// transport layer. +type TransportError interface { + tcpip.SockErrorCause + + // Kind returns the type of the transport error. + Kind() TransportErrorKind +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { @@ -93,10 +101,10 @@ type TransportEndpoint interface { // HandlePacket takes ownership of the packet. HandlePacket(TransportEndpointID, *PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g. - // ICMP) packets arrive to this transport endpoint. - // HandleControlPacket takes ownership of pkt. - HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer) + // HandleError is called when the transport endpoint receives an error. + // + // HandleError takes ownership of the packet buffer. + HandleError(TransportError, *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -248,14 +256,11 @@ type TransportDispatcher interface { // DeliverTransportPacket takes ownership of the packet. DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition - // DeliverTransportControlPacket delivers control packets to the - // appropriate transport protocol endpoint. - // - // pkt.NetworkHeader must be set before calling - // DeliverTransportControlPacket. + // DeliverTransportError delivers an error to the appropriate transport + // endpoint. // - // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) + // DeliverTransportError takes ownership of the packet buffer. + DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -530,6 +535,17 @@ type NetworkInterface interface { // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + + // HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP + // request or NDP Neighbor Solicitation). + // + // HandleNeighborProbe assumes that the probe is valid for the network + // interface the probe was received on. + HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error + + // HandleNeighborConfirmation processes an incoming neighbor confirmation + // (e.g. ARP reply or NDP Neighbor Advertisement). + HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) tcpip.Error } // LinkResolvableNetworkEndpoint handles link resolution events. @@ -649,7 +665,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint + NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -824,16 +840,12 @@ type InjectableLinkEndpoint interface { InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// A LinkAddressResolver is an extension to a NetworkProtocol that -// can resolve link addresses. +// A LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target // address. The request is broadcasted on the local network if a remote link // address is not provided. - // - // The request is sent from the passed network interface. If the interface - // local address is unspecified, any interface local address may be used. - LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the @@ -847,12 +859,6 @@ type LinkAddressResolver interface { LinkAddressProtocol() tcpip.NetworkProtocolNumber } -// A LinkAddressCache caches link addresses. -type LinkAddressCache interface { - // AddLinkAddress adds a link address to the cache. - AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress) -} - // RawFactory produces endpoints for writing various types of raw packets. type RawFactory interface { // NewUnassociatedEndpoint produces endpoints for writing packets not diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 4ae0f2a1a..bab55ce49 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -53,7 +53,7 @@ type Route struct { // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. - linkRes LinkAddressResolver + linkRes linkResolver } type routeInfo struct { @@ -174,7 +174,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA } if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { + if linkRes, ok := r.outgoingNIC.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes } } @@ -184,11 +184,11 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } - if r.linkRes == nil { + if r.linkRes.resolver == nil { return r } - if linkAddr, ok := r.linkRes.ResolveStaticAddress(r.RemoteAddress); ok { + if linkAddr, ok := r.linkRes.resolver.ResolveStaticAddress(r.RemoteAddress); ok { r.ResolveWith(linkAddr) return r } @@ -362,7 +362,7 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn } afterResolveFields := fields - linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { + linkAddr, ch, err := r.linkRes.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, func(r LinkResolutionResult) { if afterResolve != nil { if r.Success { afterResolveFields.RemoteLinkAddress = r.LinkAddress @@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -528,5 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool { // "Reachable" is defined as having full-duplex communication between the // local and remote ends of the route. func (r *Route) ConfirmReachable() { - r.outgoingNIC.confirmReachable(r.nextHop()) + if r.linkRes.resolver != nil { + r.linkRes.confirmReachable(r.nextHop()) + } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 119c4c505..57ad412a1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -376,7 +376,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { type Stack struct { transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol - linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. @@ -386,6 +385,15 @@ type Stack struct { stats tcpip.Stats + // LOCK ORDERING: mu > route.mu. + route struct { + mu struct { + sync.RWMutex + + table []tcpip.Route + } + } + mu sync.RWMutex nics map[tcpip.NICID]*NIC @@ -393,11 +401,6 @@ type Stack struct { cleanupEndpointsMu sync.Mutex cleanupEndpoints map[TransportEndpoint]struct{} - // route is the route table passed in by the user via SetRouteTable(), - // it is used by FindRoute() to build a route for a specific - // destination. - routeTable []tcpip.Route - *ports.PortManager // If not nil, then any new endpoints will have this probe function @@ -433,6 +436,8 @@ type Stack struct { // useNeighborCache indicates whether ARP and NDP packets should be handled // by the NIC's neighborCache instead of linkAddrCache. + // + // TODO(gvisor.dev/issue/4658): Remove this field. useNeighborCache bool // nudDisp is the NUD event dispatcher that is used to send the netstack @@ -499,13 +504,17 @@ type Options struct { // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations - // UseNeighborCache indicates whether ARP and NDP packets should be handled - // by the Neighbor Unreachability Detection (NUD) state machine. This flag - // also enables the APIs for inspecting and modifying the neighbor table via - // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor, - // and ClearNeighbors. + // UseNeighborCache is unused. + // + // TODO(gvisor.dev/issue/4658): Remove this field. UseNeighborCache bool + // UseLinkAddrCache indicates that the legacy link address cache should be + // used for link resolution. + // + // TODO(gvisor.dev/issue/4658): Remove this field. + UseLinkAddrCache bool + // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. NUDDisp NUDDispatcher @@ -635,7 +644,6 @@ func New(opts Options) *Stack { s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), cleanupEndpoints: make(map[TransportEndpoint]struct{}), PortManager: ports.NewPortManager(), @@ -646,7 +654,7 @@ func New(opts Options) *Stack { icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, - useNeighborCache: opts.UseNeighborCache, + useNeighborCache: !opts.UseLinkAddrCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), @@ -666,9 +674,6 @@ func New(opts Options) *Stack { for _, netProtoFactory := range opts.NetworkProtocols { netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto - if r, ok := netProto.(LinkAddressResolver); ok { - s.linkAddrResolvers[r.LinkAddressProtocol()] = r - } } // Add specified transport protocols. @@ -818,38 +823,37 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { // // This method takes ownership of the table. func (s *Stack) SetRouteTable(table []tcpip.Route) { - s.mu.Lock() - defer s.mu.Unlock() - - s.routeTable = table + s.route.mu.Lock() + defer s.route.mu.Unlock() + s.route.mu.table = table } // GetRouteTable returns the route table which is currently in use. func (s *Stack) GetRouteTable() []tcpip.Route { - s.mu.Lock() - defer s.mu.Unlock() - return append([]tcpip.Route(nil), s.routeTable...) + s.route.mu.RLock() + defer s.route.mu.RUnlock() + return append([]tcpip.Route(nil), s.route.mu.table...) } // AddRoute appends a route to the route table. func (s *Stack) AddRoute(route tcpip.Route) { - s.mu.Lock() - defer s.mu.Unlock() - s.routeTable = append(s.routeTable, route) + s.route.mu.Lock() + defer s.route.mu.Unlock() + s.route.mu.table = append(s.route.mu.table, route) } // RemoveRoutes removes matching routes from the route table. func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { - s.mu.Lock() - defer s.mu.Unlock() + s.route.mu.Lock() + defer s.route.mu.Unlock() var filteredRoutes []tcpip.Route - for _, route := range s.routeTable { + for _, route := range s.route.mu.table { if !match(route) { filteredRoutes = append(filteredRoutes, route) } } - s.routeTable = filteredRoutes + s.route.mu.table = filteredRoutes } // NewEndpoint creates a new transport layer endpoint of the given protocol. @@ -1022,17 +1026,18 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { delete(s.nics, id) // Remove routes in-place. n tracks the number of routes written. + s.route.mu.Lock() n := 0 - for i, r := range s.routeTable { - s.routeTable[i] = tcpip.Route{} + for i, r := range s.route.mu.table { + s.route.mu.table[i] = tcpip.Route{} if r.NIC != id { // Keep this route. - s.routeTable[n] = r + s.route.mu.table[n] = r n++ } } - - s.routeTable = s.routeTable[:n] + s.route.mu.table = s.route.mu.table[:n] + s.route.mu.Unlock() return nic.remove() } @@ -1357,39 +1362,49 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n // Find a route to the remote with the route table. var chosenRoute tcpip.Route - for _, route := range s.routeTable { - if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { - continue - } + if r := func() *Route { + s.route.mu.RLock() + defer s.route.mu.RUnlock() - nic, ok := s.nics[route.NIC] - if !ok || !nic.Enabled() { - continue - } + for _, route := range s.route.mu.table { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } - if id == 0 || id == route.NIC { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - var gateway tcpip.Address - if needRoute { - gateway = route.Gateway - } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) - if r == nil { - panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) + if r == nil { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r } - return r, nil } - } - // If the stack has forwarding enabled and we haven't found a valid route to - // the remote address yet, keep track of the first valid route. We keep - // iterating because we prefer routes that let us use a local address that - // is assigned to the outgoing interface. There is no requirement to do this - // from any RFC but simply a choice made to better follow a strong host - // model which the netstack follows at the time of writing. - if canForward && chosenRoute == (tcpip.Route{}) { - chosenRoute = route + // If the stack has forwarding enabled and we haven't found a valid route + // to the remote address yet, keep track of the first valid route. We + // keep iterating because we prefer routes that let us use a local + // address that is assigned to the outgoing interface. There is no + // requirement to do this from any RFC but simply a choice made to better + // follow a strong host model which the netstack follows at the time of + // writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } } + + return nil + }(); r != nil { + return r, nil } if chosenRoute != (tcpip.Route{}) { @@ -1517,20 +1532,6 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error { return nil } -// AddLinkAddress adds a link address for the neighbor on the specified NIC. -func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - nic, ok := s.nics[nicID] - if !ok { - return &tcpip.ErrUnknownNICID{} - } - - nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr) - return nil -} - // LinkResolutionResult is the result of a link address resolution attempt. type LinkResolutionResult struct { LinkAddress tcpip.LinkAddress @@ -1561,22 +1562,11 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, return &tcpip.ErrUnknownNICID{} } - linkRes, ok := s.linkAddrResolvers[protocol] - if !ok { - return &tcpip.ErrNotSupported{} - } - - if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { - onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) - return nil - } - - _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) - return err + return nic.getLinkAddress(addr, localAddr, protocol, onResolve) } // Neighbors returns all IP to MAC address associations. -func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { +func (s *Stack) Neighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1585,11 +1575,11 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { return nil, &tcpip.ErrUnknownNICID{} } - return nic.neighbors() + return nic.neighbors(protocol) } // AddStaticNeighbor statically associates an IP address to a MAC address. -func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1598,13 +1588,13 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd return &tcpip.ErrUnknownNICID{} } - return nic.addStaticNeighbor(addr, linkAddr) + return nic.addStaticNeighbor(addr, protocol, linkAddr) } // RemoveNeighbor removes an IP to MAC address association previously created // either automically or by AddStaticNeighbor. Returns ErrBadAddress if there // is no association with the provided address. -func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error { +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1613,11 +1603,11 @@ func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Erro return &tcpip.ErrUnknownNICID{} } - return nic.removeNeighbor(addr) + return nic.removeNeighbor(protocol, addr) } // ClearNeighbors removes all IP to MAC address associations. -func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { +func (s *Stack) ClearNeighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1626,7 +1616,7 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { return &tcpip.ErrUnknownNICID{} } - return nic.clearNeighbors() + return nic.clearNeighbors(protocol) } // RegisterTransportEndpoint registers the given endpoint with the stack @@ -1996,7 +1986,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco } // NUDConfigurations gets the per-interface NUD configurations. -func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) { +func (s *Stack) NUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() @@ -2005,14 +1995,14 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Erro return NUDConfigurations{}, &tcpip.ErrUnknownNICID{} } - return nic.nudConfigs() + return nic.nudConfigs(proto) } // SetNUDConfigurations sets the per-interface NUD configurations. // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error { +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() @@ -2021,7 +2011,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip. return &tcpip.ErrUnknownNICID{} } - return nic.setNUDConfigs(c) + return nic.setNUDConfigs(proto, c) } // Seed returns a 32 bit value that can be used as a seed value for port diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 41f95811f..b641a4aaa 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -137,12 +138,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket( + f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), - stack.ControlPortUnreachable, 0, pkt) + // Nothing checks the error. + nil, /* transport error */ + pkt, + ) return } @@ -243,7 +247,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &fakeNetworkEndpoint{ nic: nic, proto: f, @@ -4313,9 +4317,11 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, UseNeighborCache: true, + Clock: clock, }) e := channel.New(0, 0, "") e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -4323,36 +4329,56 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err) - } - if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err) + addrs := []struct { + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + proto: ipv4.ProtocolNumber, + addr: ipv4Addr, + }, + { + proto: ipv6.ProtocolNumber, + addr: ipv6Addr, + }, } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 2 { - t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err) + } + + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) + } } // Disabling the NIC should clear the neighbor table. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } // Enabling the NIC should have an empty neighbor table. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } } @@ -4391,7 +4417,9 @@ func TestStaticGetLinkAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, }) - if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + e := channel.New(0, 0, "") + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 26eceb804..7d8d0851e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -182,9 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } -// handleControlPacket delivers a control packet to the transport endpoint -// identified by id. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { +// handleError delivers an error to the transport endpoint identified by id. +func (epsByNIC *endpointsByNIC) handleError(n *NIC, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -200,7 +199,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -596,9 +595,11 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb return foundRaw } -// deliverControlPacket attempts to deliver the given control packet. Returns -// true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { +// deliverError attempts to deliver the given error to the appropriate transport +// endpoint. +// +// Returns true if the error was delivered. +func (d *transportDemuxer) deliverError(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -611,7 +612,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return false } - ep.handleControlPacket(n, id, typ, extra, pkt) + ep.handleError(n, id, transErr, pkt) return true } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index cf5de747b..bebf4e6b5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -237,7 +237,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * f.acceptQueue = append(f.acceptQueue, ep) } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleError(stack.TransportError, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 7069352f2..f2301a9e6 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -247,6 +247,14 @@ func TestPing(t *testing.T) { } } +type transportError struct { + origin tcpip.SockErrOrigin + typ uint8 + code uint8 + info uint32 + kind stack.TransportErrorKind +} + func TestTCPLinkResolutionFailure(t *testing.T) { const ( host1NICID = 1 @@ -259,6 +267,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr tcpip.Address expectedWriteErr tcpip.Error sockError tcpip.SockError + transErr transportError }{ { name: "IPv4 with resolvable remote", @@ -278,10 +287,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr: ipv4Addr3.AddressWithPrefix.Address, expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - ErrType: byte(header.ICMPv4DstUnreachable), - ErrCode: byte(header.ICMPv4HostUnreachable), - ErrOrigin: tcpip.SockExtErrorOriginICMP, + Err: &tcpip.ErrNoRoute{}, Dst: tcpip.FullAddress{ NIC: host1NICID, Addr: ipv4Addr3.AddressWithPrefix.Address, @@ -293,6 +299,12 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }, NetProto: ipv4.ProtocolNumber, }, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4HostUnreachable), + kind: stack.DestinationHostUnreachableTransportError, + }, }, { name: "IPv6 without resolvable remote", @@ -300,10 +312,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr: ipv6Addr3.AddressWithPrefix.Address, expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - ErrType: byte(header.ICMPv6DstUnreachable), - ErrCode: byte(header.ICMPv6AddressUnreachable), - ErrOrigin: tcpip.SockExtErrorOriginICMP6, + Err: &tcpip.ErrNoRoute{}, Dst: tcpip.FullAddress{ NIC: host1NICID, Addr: ipv6Addr3.AddressWithPrefix.Address, @@ -315,6 +324,12 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }, NetProto: ipv6.ProtocolNumber, }, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6DstUnreachable), + code: uint8(header.ICMPv6AddressUnreachable), + kind: stack.DestinationHostUnreachableTransportError, + }, }, } @@ -393,9 +408,12 @@ func TestTCPLinkResolutionFailure(t *testing.T) { // are pre defined so we can simply compare pointers. return a == b }), - // Ignore the payload since we do not know the TCP seq/ack numbers. checker.IgnoreCmpPath( + // Ignore the payload since we do not know the TCP seq/ack numbers. "Payload", + // Ignore the cause since we will compare its properties separately + // since the concrete type of the cause is unknown. + "Cause", ), } @@ -407,6 +425,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { t.Errorf("socket error mismatch (-want +got):\n%s", diff) } + + transErr, ok := sockErr.Cause.(stack.TransportError) + if !ok { + t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause) + } + if diff := cmp.Diff( + test.transErr, + transportError{ + origin: transErr.Origin(), + typ: transErr.Type(), + code: transErr.Code(), + info: transErr.Info(), + kind: transErr.Kind(), + }, + cmp.AllowUnexported(transportError{}), + ); diff != "" { + t.Errorf("socket error mismatch (-want +got):\n%s", diff) + } }) } } @@ -1069,9 +1105,9 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // Wait for the remote's neighbor entry to be stale before creating a // TCP connection from host1 to some remote. - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID) + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err) + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) } // The maximum reachable time for a neighbor is some maximum random factor // applied to the base reachable time. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 3cf05520d..f5e1a6e45 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -778,9 +778,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { -} +// HandleError implements stack.TransportEndpoint. +func (*endpoint) HandleError(stack.TransportError, *stack.PacketBuffer) {} // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't // expose internal socket state. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6e4e26c39..e645aa194 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2683,7 +2683,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } -func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -2692,11 +2692,8 @@ func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extr // Update the error queue if IP_RECVERR is enabled. if e.SocketOptions().GetRecvError() { e.SocketOptions().QueueErr(&tcpip.SockError{ - Err: err, - ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), - ErrType: errType, - ErrCode: errCode, - ErrInfo: extra, + Err: err, + Cause: transErr, // Linux passes the payload with the TCP header. We don't know if the TCP // header even exists, it may not for fragmented packets. Payload: pkt.Data.ToView(), @@ -2718,27 +2715,26 @@ func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extr e.notifyProtocolGoroutine(notifyError) } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { - switch typ { - case stack.ControlPacketTooBig: +// HandleError implements stack.TransportEndpoint. +func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { + handlePacketTooBig := func(mtu uint32) { e.sndBufMu.Lock() e.packetTooBigCount++ - if v := int(extra); v < e.sndMTU { + if v := int(mtu); v < e.sndMTU { e.sndMTU = v } e.sndBufMu.Unlock() - e.notifyProtocolGoroutine(notifyMTUChanged) + } - case stack.ControlNoRoute: - e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) - - case stack.ControlAddressUnreachable: - e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) - - case stack.ControlNetworkUnreachable: - e.onICMPError(&tcpip.ErrNetworkUnreachable{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) + // TODO(gvisor.dev/issues/5270): Handle all transport errors. + switch transErr.Kind() { + case stack.PacketTooBigTransportError: + handlePacketTooBig(transErr.Info()) + case stack.DestinationHostUnreachableTransportError: + e.onICMPError(&tcpip.ErrNoRoute{}, transErr, pkt) + case stack.DestinationNetworkUnreachableTransportError: + e.onICMPError(&tcpip.ErrNetworkUnreachable{}, transErr, pkt) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 31a5ddce9..afd8f4d39 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1322,7 +1322,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -1338,12 +1338,9 @@ func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extr } e.SocketOptions().QueueErr(&tcpip.SockError{ - Err: err, - ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), - ErrType: errType, - ErrCode: errCode, - ErrInfo: extra, - Payload: payload, + Err: err, + Cause: transErr, + Payload: payload, Dst: tcpip.FullAddress{ NIC: pkt.NICID, Addr: e.ID.RemoteAddress, @@ -1362,24 +1359,13 @@ func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extr e.waiterQueue.Notify(waiter.EventErr) } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { - if typ == stack.ControlPortUnreachable { +// HandleError implements stack.TransportEndpoint. +func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { + // TODO(gvisor.dev/issues/5270): Handle all transport errors. + switch transErr.Kind() { + case stack.DestinationPortUnreachableTransportError: if e.EndpointState() == StateConnected { - var errType byte - var errCode byte - switch pkt.NetworkProtocolNumber { - case header.IPv4ProtocolNumber: - errType = byte(header.ICMPv4DstUnreachable) - errCode = byte(header.ICMPv4PortUnreachable) - case header.IPv6ProtocolNumber: - errType = byte(header.ICMPv6DstUnreachable) - errCode = byte(header.ICMPv6PortUnreachable) - default: - panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) - } - e.onICMPError(&tcpip.ErrConnectionRefused{}, errType, errCode, extra, pkt) - return + e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt) } } } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 64c5298d3..5d81dbb94 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1614,7 +1614,7 @@ func TestTTL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{p}, }) - ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil) + ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) wantTTL = ep.DefaultTTL() ep.Close() } diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index d07ed6ba5..aaffabfd0 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -434,18 +434,7 @@ func TestTmpMount(t *testing.T) { // runsc to hide the incoherence of FDs opened before and after overlayfs // copy-up on the host. func TestHostOverlayfsCopyUp(t *testing.T) { - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/hostoverlaytest", - WorkDir: "/root", - }, "./test_copy_up"); err != nil { - t.Fatalf("docker run failed: %v", err) - } else if got != "" { - t.Errorf("test failed:\n%s", got) - } + runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o test_copy_up test_copy_up.c && ./test_copy_up") } // TestHostOverlayfsRewindDir tests that rewinddir() "causes the directory @@ -460,36 +449,14 @@ func TestHostOverlayfsCopyUp(t *testing.T) { // automated tests yield newly-added files from readdir() even if the fsgofer // does not explicitly rewinddir(), but overlayfs does not. func TestHostOverlayfsRewindDir(t *testing.T) { - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/hostoverlaytest", - WorkDir: "/root", - }, "./test_rewinddir"); err != nil { - t.Fatalf("docker run failed: %v", err) - } else if got != "" { - t.Errorf("test failed:\n%s", got) - } + runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o test_rewinddir test_rewinddir.c && ./test_rewinddir") } // Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it // cannot use tricks like userns as root. For this reason, run a basic link test // to ensure some coverage. func TestLink(t *testing.T) { - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/linktest", - WorkDir: "/root", - }, "./link_test"); err != nil { - t.Fatalf("docker run failed: %v", err) - } else if got != "" { - t.Errorf("test failed:\n%s", got) - } + runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o link_test link_test.c && ./link_test") } // This test ensures we can run ping without errors. @@ -500,17 +467,7 @@ func TestPing4Loopback(t *testing.T) { t.Skip("hostnet only supports TCP/UDP sockets, so ping is not supported.") } - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/ping4test", - }, "/root/ping4.sh"); err != nil { - t.Fatalf("docker run failed: %s", err) - } else if got != "" { - t.Errorf("test failed:\n%s", got) - } + runIntegrationTest(t, nil, "./ping4.sh") } // This test ensures we can enable ipv6 on loopback and run ping6 without @@ -522,20 +479,25 @@ func TestPing6Loopback(t *testing.T) { t.Skip("hostnet only supports TCP/UDP sockets, so ping6 is not supported.") } + // The CAP_NET_ADMIN capability is required to use the `ip` utility, which + // we use to enable ipv6 on loopback. + // + // By default, ipv6 loopback is not enabled by runsc, because docker does + // not assign an ipv6 address to the test container. + runIntegrationTest(t, []string{"NET_ADMIN"}, "./ping6.sh") +} + +func runIntegrationTest(t *testing.T, capAdd []string, args ...string) { ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/ping6test", - // The CAP_NET_ADMIN capability is required to use the `ip` utility, which - // we use to enable ipv6 on loopback. - // - // By default, ipv6 loopback is not enabled by runsc, because docker does - // not assign an ipv6 address to the test container. - CapAdd: []string{"NET_ADMIN"}, - }, "/root/ping6.sh"); err != nil { - t.Fatalf("docker run failed: %s", err) + Image: "basic/integrationtest", + WorkDir: "/root", + CapAdd: capAdd, + }, args...); err != nil { + t.Fatalf("docker run failed: %v", err) } else if got != "" { t.Errorf("test failed:\n%s", got) } diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go index b66e8f609..69275e54b 100644 --- a/test/packetimpact/tests/tcp_info_test.go +++ b/test/packetimpact/tests/tcp_info_test.go @@ -55,6 +55,9 @@ func TestTCPInfo(t *testing.T) { info := linux.TCPInfo{} infoBytes := dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } binary.Unmarshal(infoBytes, usermem.ByteOrder, &info) rtt := time.Duration(info.RTT) * time.Microsecond @@ -93,6 +96,9 @@ func TestTCPInfo(t *testing.T) { info = linux.TCPInfo{} infoBytes = dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } binary.Unmarshal(infoBytes, usermem.ByteOrder, &info) if info.CaState != linux.TCP_CA_Loss { t.Errorf("expected the connection to be in loss recovery, got: %v want: %v", info.CaState, linux.TCP_CA_Loss) diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 6ee2b73c1..e43f30ba3 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -557,6 +557,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:setgid_test", +) + +syscall_test( add_overlay = True, test = "//test/syscalls/linux:splice_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 2b4b6f348..80e2837f8 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -482,7 +482,9 @@ cc_binary( "//test/util:fs_util", "@com_google_absl//absl/strings", gtest, + "//test/util:logging", "//test/util:mount_util", + "//test/util:multiprocess_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -672,6 +674,7 @@ cc_binary( gtest, "//test/util:logging", "//test/util:memory_util", + "//test/util:multiprocess_util", "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", @@ -1380,6 +1383,7 @@ cc_binary( "//test/util:file_descriptor", "//test/util:fs_util", gtest, + "//test/util:posix_error", "//test/util:temp_path", "//test/util:temp_umask", "//test/util:test_main", @@ -2143,6 +2147,24 @@ cc_binary( ) cc_binary( + name = "setgid_test", + testonly = 1, + srcs = ["setgid.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:cleanup", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + gtest, + ], +) + +cc_binary( name = "splice_test", testonly = 1, srcs = ["splice.cc"], @@ -3827,6 +3849,8 @@ cc_binary( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", gtest, + "//test/util:cleanup", + "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", @@ -4082,6 +4106,7 @@ cc_binary( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", gtest, + "//test/util:cleanup", "//test/util:test_main", "//test/util:test_util", ], diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc index 85ec013d5..fab79d300 100644 --- a/test/syscalls/linux/chroot.cc +++ b/test/syscalls/linux/chroot.cc @@ -32,7 +32,9 @@ #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/logging.h" #include "test/util/mount_util.h" +#include "test/util/multiprocess_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -47,17 +49,20 @@ namespace { TEST(ChrootTest, Success) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); + const auto rest = [] { + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK_SUCCESS(chroot(temp_dir.path().c_str())); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } TEST(ChrootTest, PermissionDenied) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - // CAP_DAC_READ_SEARCH and CAP_DAC_OVERRIDE may override Execute permission on - // directories. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + // CAP_DAC_READ_SEARCH and CAP_DAC_OVERRIDE may override Execute permission + // on directories. + AutoCapability cap_search(CAP_DAC_READ_SEARCH, false); + AutoCapability cap_override(CAP_DAC_OVERRIDE, false); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); @@ -78,8 +83,10 @@ TEST(ChrootTest, NotExist) { } TEST(ChrootTest, WithoutCapability) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETPCAP))); + // Unset CAP_SYS_CHROOT. - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_CHROOT, false)); + AutoCapability cap(CAP_SYS_CHROOT, false); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallFailsWithErrno(EPERM)); @@ -97,51 +104,53 @@ TEST(ChrootTest, CreatesNewRoot) { auto file_in_new_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(new_root.path())); - // chroot into new_root. - ASSERT_THAT(chroot(new_root.path().c_str()), SyscallSucceeds()); - - // getcwd should return "(unreachable)" followed by the initial_cwd. - char cwd[1024]; - ASSERT_THAT(syscall(__NR_getcwd, cwd, sizeof(cwd)), SyscallSucceeds()); - std::string expected_cwd = "(unreachable)"; - expected_cwd += initial_cwd; - EXPECT_STREQ(cwd, expected_cwd.c_str()); - - // Should not be able to stat file by its full path. - struct stat statbuf; - EXPECT_THAT(stat(file_in_new_root.path().c_str(), &statbuf), - SyscallFailsWithErrno(ENOENT)); - - // Should be able to stat file at new rooted path. - auto basename = std::string(Basename(file_in_new_root.path())); - auto rootedFile = "/" + basename; - ASSERT_THAT(stat(rootedFile.c_str(), &statbuf), SyscallSucceeds()); - - // Should be able to stat cwd at '.' even though it's outside root. - ASSERT_THAT(stat(".", &statbuf), SyscallSucceeds()); - - // chdir into new root. - ASSERT_THAT(chdir("/"), SyscallSucceeds()); - - // getcwd should return "/". - EXPECT_THAT(syscall(__NR_getcwd, cwd, sizeof(cwd)), SyscallSucceeds()); - EXPECT_STREQ(cwd, "/"); - - // Statting '.', '..', '/', and '/..' all return the same dev and inode. - struct stat statbuf_dot; - ASSERT_THAT(stat(".", &statbuf_dot), SyscallSucceeds()); - struct stat statbuf_dotdot; - ASSERT_THAT(stat("..", &statbuf_dotdot), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_dotdot.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_dotdot.st_ino); - struct stat statbuf_slash; - ASSERT_THAT(stat("/", &statbuf_slash), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_slash.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_slash.st_ino); - struct stat statbuf_slashdotdot; - ASSERT_THAT(stat("/..", &statbuf_slashdotdot), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_slashdotdot.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_slashdotdot.st_ino); + const auto rest = [&] { + // chroot into new_root. + TEST_CHECK_SUCCESS(chroot(new_root.path().c_str())); + + // getcwd should return "(unreachable)" followed by the initial_cwd. + char cwd[1024]; + TEST_CHECK_SUCCESS(syscall(__NR_getcwd, cwd, sizeof(cwd))); + std::string expected_cwd = "(unreachable)"; + expected_cwd += initial_cwd; + TEST_CHECK(strcmp(cwd, expected_cwd.c_str()) == 0); + + // Should not be able to stat file by its full path. + struct stat statbuf; + TEST_CHECK_ERRNO(stat(file_in_new_root.path().c_str(), &statbuf), ENOENT); + + // Should be able to stat file at new rooted path. + auto basename = std::string(Basename(file_in_new_root.path())); + auto rootedFile = "/" + basename; + TEST_CHECK_SUCCESS(stat(rootedFile.c_str(), &statbuf)); + + // Should be able to stat cwd at '.' even though it's outside root. + TEST_CHECK_SUCCESS(stat(".", &statbuf)); + + // chdir into new root. + TEST_CHECK_SUCCESS(chdir("/")); + + // getcwd should return "/". + TEST_CHECK_SUCCESS(syscall(__NR_getcwd, cwd, sizeof(cwd))); + TEST_CHECK_SUCCESS(strcmp(cwd, "/") == 0); + + // Statting '.', '..', '/', and '/..' all return the same dev and inode. + struct stat statbuf_dot; + TEST_CHECK_SUCCESS(stat(".", &statbuf_dot)); + struct stat statbuf_dotdot; + TEST_CHECK_SUCCESS(stat("..", &statbuf_dotdot)); + TEST_CHECK(statbuf_dot.st_dev == statbuf_dotdot.st_dev); + TEST_CHECK(statbuf_dot.st_ino == statbuf_dotdot.st_ino); + struct stat statbuf_slash; + TEST_CHECK_SUCCESS(stat("/", &statbuf_slash)); + TEST_CHECK(statbuf_dot.st_dev == statbuf_slash.st_dev); + TEST_CHECK(statbuf_dot.st_ino == statbuf_slash.st_ino); + struct stat statbuf_slashdotdot; + TEST_CHECK_SUCCESS(stat("/..", &statbuf_slashdotdot)); + TEST_CHECK(statbuf_dot.st_dev == statbuf_slashdotdot.st_dev); + TEST_CHECK(statbuf_dot.st_ino == statbuf_slashdotdot.st_ino); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } TEST(ChrootTest, DotDotFromOpenFD) { @@ -152,18 +161,20 @@ TEST(ChrootTest, DotDotFromOpenFD) { Open(dir_outside_root.path(), O_RDONLY | O_DIRECTORY)); auto new_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - // chroot into new_root. - ASSERT_THAT(chroot(new_root.path().c_str()), SyscallSucceeds()); + const auto rest = [&] { + // chroot into new_root. + TEST_CHECK_SUCCESS(chroot(new_root.path().c_str())); - // openat on fd with path .. will succeed. - int other_fd; - ASSERT_THAT(other_fd = openat(fd.get(), "..", O_RDONLY), SyscallSucceeds()); - EXPECT_THAT(close(other_fd), SyscallSucceeds()); + // openat on fd with path .. will succeed. + int other_fd; + TEST_CHECK_SUCCESS(other_fd = openat(fd.get(), "..", O_RDONLY)); + TEST_CHECK_SUCCESS(close(other_fd)); - // getdents on fd should not error. - char buf[1024]; - ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)), - SyscallSucceeds()); + // getdents on fd should not error. + char buf[1024]; + TEST_CHECK_SUCCESS(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf))); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // Test that link resolution in a chroot can escape the root by following an @@ -179,24 +190,27 @@ TEST(ChrootTest, ProcFdLinkResolutionInChroot) { const FileDescriptor proc_fd = ASSERT_NO_ERRNO_AND_VALUE( Open("/proc", O_DIRECTORY | O_RDONLY | O_CLOEXEC)); - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Opening relative to an already open fd to a node outside the chroot works. - const FileDescriptor proc_self_fd = ASSERT_NO_ERRNO_AND_VALUE( - OpenAt(proc_fd.get(), "self/fd", O_DIRECTORY | O_RDONLY | O_CLOEXEC)); - - // Proc fd symlinks can escape the chroot if the fd the symlink refers to - // refers to an object outside the chroot. - struct stat s = {}; - EXPECT_THAT( - fstatat(proc_self_fd.get(), absl::StrCat(fd.get()).c_str(), &s, 0), - SyscallSucceeds()); - - // Try to stat the stdin fd. Internally, this is handled differently from a - // proc fd entry pointing to a file, since stdin is backed by a host fd, and - // isn't a walkable path on the filesystem inside the sandbox. - EXPECT_THAT(fstatat(proc_self_fd.get(), "0", &s, 0), SyscallSucceeds()); + const auto rest = [&] { + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK_SUCCESS(chroot(temp_dir.path().c_str())); + + // Opening relative to an already open fd to a node outside the chroot + // works. + const FileDescriptor proc_self_fd = TEST_CHECK_NO_ERRNO_AND_VALUE( + OpenAt(proc_fd.get(), "self/fd", O_DIRECTORY | O_RDONLY | O_CLOEXEC)); + + // Proc fd symlinks can escape the chroot if the fd the symlink refers to + // refers to an object outside the chroot. + struct stat s = {}; + TEST_CHECK_SUCCESS( + fstatat(proc_self_fd.get(), absl::StrCat(fd.get()).c_str(), &s, 0)); + + // Try to stat the stdin fd. Internally, this is handled differently from a + // proc fd entry pointing to a file, since stdin is backed by a host fd, and + // isn't a walkable path on the filesystem inside the sandbox. + TEST_CHECK_SUCCESS(fstatat(proc_self_fd.get(), "0", &s, 0)); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // This test will verify that when you hold a fd to proc before entering @@ -209,28 +223,30 @@ TEST(ChrootTest, ProcMemSelfFdsNoEscapeProcOpen) { const FileDescriptor proc = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - // Create and enter a chroot directory. - const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Open a file inside the chroot at /foo. - const FileDescriptor foo = - ASSERT_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); - - // Examine /proc/self/fd/{foo_fd} to see if it exposes the fact that we're - // inside a chroot, the path should be /foo and NOT {chroot_dir}/foo. - const std::string fd_path = absl::StrCat("self/fd/", foo.get()); - char buf[1024] = {}; - size_t bytes_read = 0; - ASSERT_THAT(bytes_read = - readlinkat(proc.get(), fd_path.c_str(), buf, sizeof(buf) - 1), - SyscallSucceeds()); - - // The link should resolve to something. - ASSERT_GT(bytes_read, 0); - - // Assert that the link doesn't contain the chroot path and is only /foo. - EXPECT_STREQ(buf, "/foo"); + const auto rest = [&] { + // Create and enter a chroot directory. + const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK_SUCCESS(chroot(temp_dir.path().c_str())); + + // Open a file inside the chroot at /foo. + const FileDescriptor foo = + TEST_CHECK_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); + + // Examine /proc/self/fd/{foo_fd} to see if it exposes the fact that we're + // inside a chroot, the path should be /foo and NOT {chroot_dir}/foo. + const std::string fd_path = absl::StrCat("self/fd/", foo.get()); + char buf[1024] = {}; + size_t bytes_read = 0; + TEST_CHECK_SUCCESS(bytes_read = readlinkat(proc.get(), fd_path.c_str(), buf, + sizeof(buf) - 1)); + + // The link should resolve to something. + TEST_CHECK(bytes_read > 0); + + // Assert that the link doesn't contain the chroot path and is only /foo. + TEST_CHECK(strcmp(buf, "/foo") == 0); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // This test will verify that a file inside a chroot when mmapped will not @@ -242,39 +258,41 @@ TEST(ChrootTest, ProcMemSelfMapsNoEscapeProcOpen) { const FileDescriptor proc = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - // Create and enter a chroot directory. - const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Open a file inside the chroot at /foo. - const FileDescriptor foo = - ASSERT_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); - - // Mmap the newly created file. - void* foo_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - foo.get(), 0); - ASSERT_THAT(reinterpret_cast<int64_t>(foo_map), SyscallSucceeds()); - - // Always unmap. - auto cleanup_map = Cleanup( - [&] { EXPECT_THAT(munmap(foo_map, kPageSize), SyscallSucceeds()); }); - - // Examine /proc/self/maps to be sure that /foo doesn't appear to be - // mapped with the full chroot path. - const FileDescriptor maps = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), "self/maps", O_RDONLY)); - - size_t bytes_read = 0; - char buf[8 * 1024] = {}; - ASSERT_THAT(bytes_read = ReadFd(maps.get(), buf, sizeof(buf)), - SyscallSucceeds()); - - // The maps file should have something. - ASSERT_GT(bytes_read, 0); - - // Finally we want to make sure the maps don't contain the chroot path - ASSERT_EQ(std::string(buf, bytes_read).find(temp_dir.path()), - std::string::npos); + const auto rest = [&] { + // Create and enter a chroot directory. + const auto temp_dir = TEST_CHECK_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK_SUCCESS(chroot(temp_dir.path().c_str())); + + // Open a file inside the chroot at /foo. + const FileDescriptor foo = + TEST_CHECK_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); + + // Mmap the newly created file. + void* foo_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE, foo.get(), 0); + TEST_CHECK_SUCCESS(reinterpret_cast<int64_t>(foo_map)); + + // Always unmap. + auto cleanup_map = + Cleanup([&] { TEST_CHECK_SUCCESS(munmap(foo_map, kPageSize)); }); + + // Examine /proc/self/maps to be sure that /foo doesn't appear to be + // mapped with the full chroot path. + const FileDescriptor maps = TEST_CHECK_NO_ERRNO_AND_VALUE( + OpenAt(proc.get(), "self/maps", O_RDONLY)); + + size_t bytes_read = 0; + char buf[8 * 1024] = {}; + TEST_CHECK_SUCCESS(bytes_read = ReadFd(maps.get(), buf, sizeof(buf))); + + // The maps file should have something. + TEST_CHECK(bytes_read > 0); + + // Finally we want to make sure the maps don't contain the chroot path + TEST_CHECK(std::string(buf, bytes_read).find(temp_dir.path()) == + std::string::npos); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // Test that mounts outside the chroot will not appear in /proc/self/mounts or @@ -283,81 +301,76 @@ TEST(ChrootTest, ProcMountsMountinfoNoEscape) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - // We are going to create some mounts and then chroot. In order to be able to - // unmount the mounts after the test run, we must chdir to the root and use - // relative paths for all mounts. That way, as long as we never chdir into - // the new root, we can access the mounts via relative paths and unmount them. - ASSERT_THAT(chdir("/"), SyscallSucceeds()); - - // Create nested tmpfs mounts. Note the use of relative paths in Mount calls. + // Create nested tmpfs mounts. auto const outer_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const outer_mount = ASSERT_NO_ERRNO_AND_VALUE(Mount( - "none", JoinPath(".", outer_dir.path()), "tmpfs", 0, "mode=0700", 0)); + auto const outer_mount = ASSERT_NO_ERRNO_AND_VALUE( + Mount("none", outer_dir.path(), "tmpfs", 0, "mode=0700", 0)); auto const inner_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(outer_dir.path())); - auto const inner_mount = ASSERT_NO_ERRNO_AND_VALUE(Mount( - "none", JoinPath(".", inner_dir.path()), "tmpfs", 0, "mode=0700", 0)); - - // Filenames that will be checked for mounts, all relative to /proc dir. - std::string paths[3] = {"mounts", "self/mounts", "self/mountinfo"}; - - for (const std::string& path : paths) { - // We should have both inner and outer mounts. - const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContents(JoinPath("/proc", path))); - EXPECT_THAT(contents, AllOf(HasSubstr(outer_dir.path()), - HasSubstr(inner_dir.path()))); - // We better have at least two mounts: the mounts we created plus the root. - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_GT(submounts.size(), 2); - } - - // Get a FD to /proc before we enter the chroot. - const FileDescriptor proc = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - - // Chroot to outer mount. - ASSERT_THAT(chroot(outer_dir.path().c_str()), SyscallSucceeds()); - - for (const std::string& path : paths) { - const FileDescriptor proc_file = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); - - // Only two mounts visible from this chroot: the inner and outer. Both - // paths should be relative to the new chroot. - const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); - EXPECT_THAT(contents, - AllOf(HasSubstr(absl::StrCat(Basename(inner_dir.path()))), - Not(HasSubstr(outer_dir.path())), - Not(HasSubstr(inner_dir.path())))); - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_EQ(submounts.size(), 2); - } - - // Chroot to inner mount. We must use an absolute path accessible to our - // chroot. - const std::string inner_dir_basename = - absl::StrCat("/", Basename(inner_dir.path())); - ASSERT_THAT(chroot(inner_dir_basename.c_str()), SyscallSucceeds()); - - for (const std::string& path : paths) { - const FileDescriptor proc_file = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); - const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); - - // Only the inner mount visible from this chroot. - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_EQ(submounts.size(), 1); - } - - // Chroot back to ".". - ASSERT_THAT(chroot("."), SyscallSucceeds()); + auto const inner_mount = ASSERT_NO_ERRNO_AND_VALUE( + Mount("none", inner_dir.path(), "tmpfs", 0, "mode=0700", 0)); + + const auto rest = [&outer_dir, &inner_dir] { + // Filenames that will be checked for mounts, all relative to /proc dir. + std::string paths[3] = {"mounts", "self/mounts", "self/mountinfo"}; + + for (const std::string& path : paths) { + // We should have both inner and outer mounts. + const std::string contents = + TEST_CHECK_NO_ERRNO_AND_VALUE(GetContents(JoinPath("/proc", path))); + EXPECT_THAT(contents, AllOf(HasSubstr(outer_dir.path()), + HasSubstr(inner_dir.path()))); + // We better have at least two mounts: the mounts we created plus the + // root. + std::vector<absl::string_view> submounts = + absl::StrSplit(contents, '\n', absl::SkipWhitespace()); + TEST_CHECK(submounts.size() > 2); + } + + // Get a FD to /proc before we enter the chroot. + const FileDescriptor proc = + TEST_CHECK_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); + + // Chroot to outer mount. + TEST_CHECK_SUCCESS(chroot(outer_dir.path().c_str())); + + for (const std::string& path : paths) { + const FileDescriptor proc_file = + TEST_CHECK_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); + + // Only two mounts visible from this chroot: the inner and outer. Both + // paths should be relative to the new chroot. + const std::string contents = + TEST_CHECK_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); + EXPECT_THAT(contents, + AllOf(HasSubstr(absl::StrCat(Basename(inner_dir.path()))), + Not(HasSubstr(outer_dir.path())), + Not(HasSubstr(inner_dir.path())))); + std::vector<absl::string_view> submounts = + absl::StrSplit(contents, '\n', absl::SkipWhitespace()); + TEST_CHECK(submounts.size() == 2); + } + + // Chroot to inner mount. We must use an absolute path accessible to our + // chroot. + const std::string inner_dir_basename = + absl::StrCat("/", Basename(inner_dir.path())); + TEST_CHECK_SUCCESS(chroot(inner_dir_basename.c_str())); + + for (const std::string& path : paths) { + const FileDescriptor proc_file = + TEST_CHECK_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); + const std::string contents = + TEST_CHECK_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); + + // Only the inner mount visible from this chroot. + std::vector<absl::string_view> submounts = + absl::StrSplit(contents, '\n', absl::SkipWhitespace()); + TEST_CHECK(submounts.size() == 1); + } + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } } // namespace diff --git a/test/syscalls/linux/getrusage.cc b/test/syscalls/linux/getrusage.cc index 0e51d42a8..e84cbfdc3 100644 --- a/test/syscalls/linux/getrusage.cc +++ b/test/syscalls/linux/getrusage.cc @@ -23,6 +23,7 @@ #include "absl/time/time.h" #include "test/util/logging.h" #include "test/util/memory_util.h" +#include "test/util/multiprocess_util.h" #include "test/util/signal_util.h" #include "test/util/test_util.h" @@ -93,59 +94,66 @@ TEST(GetrusageTest, Grandchild) { // Verifies that processes ignoring SIGCHLD do not have updated child maxrss // updated. TEST(GetrusageTest, IgnoreSIGCHLD) { - struct sigaction sa; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa)); - pid_t pid = fork(); - if (pid == 0) { + const auto rest = [] { + struct sigaction sa; + sa.sa_handler = SIG_IGN; + sa.sa_flags = 0; + auto cleanup = TEST_CHECK_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa)); + pid_t pid = fork(); + if (pid == 0) { + struct rusage rusage_self; + TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); + // The child has consumed some memory. + TEST_CHECK(rusage_self.ru_maxrss != 0); + _exit(0); + } + TEST_CHECK_SUCCESS(pid); + int status; + TEST_CHECK_ERRNO(RetryEINTR(waitpid)(pid, &status, 0), ECHILD); struct rusage rusage_self; - TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); - // The child has consumed some memory. - TEST_CHECK(rusage_self.ru_maxrss != 0); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), - SyscallFailsWithErrno(ECHILD)); - struct rusage rusage_self; - ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds()); - struct rusage rusage_children; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds()); - // The parent has consumed some memory. - EXPECT_GT(rusage_self.ru_maxrss, 0); - // The child's maxrss should not have propagated up. - EXPECT_EQ(rusage_children.ru_maxrss, 0); + TEST_CHECK_SUCCESS(getrusage(RUSAGE_SELF, &rusage_self)); + struct rusage rusage_children; + TEST_CHECK_SUCCESS(getrusage(RUSAGE_CHILDREN, &rusage_children)); + // The parent has consumed some memory. + TEST_CHECK(rusage_self.ru_maxrss > 0); + // The child's maxrss should not have propagated up. + TEST_CHECK(rusage_children.ru_maxrss == 0); + }; + // Execute inside a forked process so that rusage_children is clean. + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // Verifies that zombie processes do not update their parent's maxrss. Only // reaped processes should do this. TEST(GetrusageTest, IgnoreZombie) { - pid_t pid = fork(); - if (pid == 0) { + const auto rest = [] { + pid_t pid = fork(); + if (pid == 0) { + struct rusage rusage_self; + TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); + struct rusage rusage_children; + TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0); + // The child has consumed some memory. + TEST_CHECK(rusage_self.ru_maxrss != 0); + // The child has no children of its own. + TEST_CHECK(rusage_children.ru_maxrss == 0); + _exit(0); + } + TEST_CHECK_SUCCESS(pid); + // Give the child time to exit. Because we don't call wait, the child should + // remain a zombie. + absl::SleepFor(absl::Seconds(5)); struct rusage rusage_self; - TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); + TEST_CHECK_SUCCESS(getrusage(RUSAGE_SELF, &rusage_self)); struct rusage rusage_children; - TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0); - // The child has consumed some memory. - TEST_CHECK(rusage_self.ru_maxrss != 0); - // The child has no children of its own. + TEST_CHECK_SUCCESS(getrusage(RUSAGE_CHILDREN, &rusage_children)); + // The parent has consumed some memory. + TEST_CHECK(rusage_self.ru_maxrss > 0); + // The child has consumed some memory, but hasn't been reaped. TEST_CHECK(rusage_children.ru_maxrss == 0); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - // Give the child time to exit. Because we don't call wait, the child should - // remain a zombie. - absl::SleepFor(absl::Seconds(5)); - struct rusage rusage_self; - ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds()); - struct rusage rusage_children; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds()); - // The parent has consumed some memory. - EXPECT_GT(rusage_self.ru_maxrss, 0); - // The child has consumed some memory, but hasn't been reaped. - EXPECT_EQ(rusage_children.ru_maxrss, 0); + }; + // Execute inside a forked process so that rusage_children is clean. + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } TEST(GetrusageTest, Wait4) { diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 733b17834..e65ffee8f 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -75,55 +75,52 @@ class OpenTest : public FileTest { }; TEST_F(OpenTest, OTrunc) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_TRUNC, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(open(dir.path().c_str(), O_TRUNC, 0666), SyscallFailsWithErrno(EISDIR)); } TEST_F(OpenTest, OTruncAndReadOnlyDir) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(open(dir.path().c_str(), O_TRUNC | O_RDONLY, 0666), SyscallFailsWithErrno(EISDIR)); } TEST_F(OpenTest, OTruncAndReadOnlyFile) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); - const FileDescriptor existing = - ASSERT_NO_ERRNO_AND_VALUE(Open(dirpath.c_str(), O_RDWR | O_CREAT, 0666)); - const FileDescriptor otrunc = ASSERT_NO_ERRNO_AND_VALUE( - Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666)); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto path = JoinPath(dir.path(), "foo"); + EXPECT_NO_ERRNO(Open(path, O_RDWR | O_CREAT, 0666)); + EXPECT_NO_ERRNO(Open(path, O_TRUNC | O_RDONLY, 0666)); } TEST_F(OpenTest, OCreateDirectory) { SKIP_IF(IsRunningWithVFS1()); - auto dirpath = GetAbsoluteTestTmpdir(); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Normal case: existing directory. - ASSERT_THAT(open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), + ASSERT_THAT(open(dir.path().c_str(), O_RDWR | O_CREAT, 0666), SyscallFailsWithErrno(EISDIR)); // Trailing separator on existing directory. - ASSERT_THAT(open(dirpath.append("/").c_str(), O_RDWR | O_CREAT, 0666), + ASSERT_THAT(open(dir.path().append("/").c_str(), O_RDWR | O_CREAT, 0666), SyscallFailsWithErrno(EISDIR)); // Trailing separator on non-existing directory. - ASSERT_THAT(open(JoinPath(dirpath, "non-existent").append("/").c_str(), + ASSERT_THAT(open(JoinPath(dir.path(), "non-existent").append("/").c_str(), O_RDWR | O_CREAT, 0666), SyscallFailsWithErrno(EISDIR)); // "." special case. - ASSERT_THAT(open(JoinPath(dirpath, ".").c_str(), O_RDWR | O_CREAT, 0666), + ASSERT_THAT(open(JoinPath(dir.path(), ".").c_str(), O_RDWR | O_CREAT, 0666), SyscallFailsWithErrno(EISDIR)); } TEST_F(OpenTest, MustCreateExisting) { - auto dirPath = GetAbsoluteTestTmpdir(); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Existing directory. - ASSERT_THAT(open(dirPath.c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), + ASSERT_THAT(open(dir.path().c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), SyscallFailsWithErrno(EEXIST)); // Existing file. - auto newFile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dirPath)); + auto newFile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); ASSERT_THAT(open(newFile.path().c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), SyscallFailsWithErrno(EEXIST)); } @@ -206,7 +203,8 @@ TEST_F(OpenTest, AtAbsPath) { } TEST_F(OpenTest, OpenNoFollowSymlink) { - const std::string link_path = JoinPath(GetAbsoluteTestTmpdir(), "link"); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string link_path = JoinPath(dir.path().c_str(), "link"); ASSERT_THAT(symlink(test_file_name_.c_str(), link_path.c_str()), SyscallSucceeds()); auto cleanup = Cleanup([link_path]() { @@ -227,8 +225,7 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) { // // We will then open tmp_folder/sym_folder/file with O_NOFOLLOW and it // should succeed as O_NOFOLLOW only applies to the final path component. - auto tmp_path = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir())); + auto tmp_path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto sym_path = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), tmp_path.path())); auto file_path = @@ -246,8 +243,7 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) { // // open("root/child/symlink/root/child/file") TEST_F(OpenTest, SymlinkRecurse) { - auto root = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir())); + auto root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); auto symlink = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo(child.path(), "../..")); @@ -481,12 +477,8 @@ TEST_F(OpenTest, CanTruncateWithStrangePermissions) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); const DisableSave ds; // Permissions are dropped. std::string path = NewTempAbsPath(); - int fd; // Create a file without user permissions. - EXPECT_THAT( // SAVE_BELOW - fd = open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); + EXPECT_NO_ERRNO(Open(path, O_CREAT | O_TRUNC | O_WRONLY, 055)); // Cannot open file because we are owner and have no permissions set. EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); @@ -495,8 +487,7 @@ TEST_F(OpenTest, CanTruncateWithStrangePermissions) { EXPECT_THAT(chmod(path.c_str(), 0755), SyscallSucceeds()); // Now we can open the file again. - EXPECT_THAT(fd = open(path.c_str(), O_RDWR), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); + EXPECT_NO_ERRNO(Open(path, O_RDWR)); } TEST_F(OpenTest, OpenNonDirectoryWithTrailingSlash) { diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 9d63782fb..f8fbea79e 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -22,6 +22,7 @@ #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/temp_umask.h" #include "test/util/test_util.h" @@ -31,85 +32,60 @@ namespace testing { namespace { TEST(CreateTest, TmpFile) { - int fd; - EXPECT_THAT(fd = open(JoinPath(GetAbsoluteTestTmpdir(), "a").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_NO_ERRNO(Open(JoinPath(dir.path(), "a"), O_RDWR | O_CREAT, 0666)); } TEST(CreateTest, ExistingFile) { - int fd; - EXPECT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto path = JoinPath(dir.path(), "ExistingFile"); + EXPECT_NO_ERRNO(Open(path, O_RDWR | O_CREAT, 0666)); + EXPECT_NO_ERRNO(Open(path, O_RDWR | O_CREAT, 0666)); } TEST(CreateTest, CreateAtFile) { - int dirfd; - EXPECT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_DIRECTORY, 0666), - SyscallSucceeds()); - EXPECT_THAT(openat(dirfd, "CreateAtFile", O_RDWR | O_CREAT, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY, 0666)); + EXPECT_THAT(openat(dirfd.get(), "CreateAtFile", O_RDWR | O_CREAT, 0666), SyscallSucceeds()); - EXPECT_THAT(close(dirfd), SyscallSucceeds()); } TEST(CreateTest, HonorsUmask_NoRandomSave) { const DisableSave ds; // file cannot be re-opened as writable. + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); TempUmask mask(0222); - int fd; - ASSERT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "UmaskedFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(dir.path(), "UmaskedFile"), O_RDWR | O_CREAT, 0666)); struct stat statbuf; - ASSERT_THAT(fstat(fd, &statbuf), SyscallSucceeds()); + ASSERT_THAT(fstat(fd.get(), &statbuf), SyscallSucceeds()); EXPECT_EQ(0444, statbuf.st_mode & 0777); - EXPECT_THAT(close(fd), SyscallSucceeds()); } TEST(CreateTest, CreateExclusively) { - std::string filename = NewTempAbsPath(); - - int fd; - ASSERT_THAT(fd = open(filename.c_str(), O_CREAT | O_RDWR, 0644), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(open(filename.c_str(), O_CREAT | O_EXCL | O_RDWR, 0644), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto path = JoinPath(dir.path(), "foo"); + EXPECT_NO_ERRNO(Open(path, O_CREAT | O_RDWR, 0644)); + EXPECT_THAT(open(path.c_str(), O_CREAT | O_EXCL | O_RDWR, 0644), SyscallFailsWithErrno(EEXIST)); } TEST(CreateTest, CreatWithOTrunc) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(open(dir.path().c_str(), O_CREAT | O_TRUNC, 0666), SyscallFailsWithErrno(EISDIR)); } TEST(CreateTest, CreatDirWithOTruncAndReadOnly) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(open(dir.path().c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), SyscallFailsWithErrno(EISDIR)); } TEST(CreateTest, CreatFileWithOTruncAndReadOnly) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); - int dirfd; - ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), - SyscallSucceeds()); - ASSERT_THAT(close(dirfd), SyscallSucceeds()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto path = JoinPath(dir.path(), "foo"); + ASSERT_NO_ERRNO(Open(path, O_RDWR | O_CREAT, 0666)); + ASSERT_NO_ERRNO(Open(path, O_CREAT | O_TRUNC | O_RDONLY, 0666)); } TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) { diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 662c6feb2..d61d94309 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -18,6 +18,7 @@ #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/test_util.h" @@ -341,6 +342,8 @@ TEST(ProcNetUnix, StreamSocketStateStateConnectedOnAccept) { int clientfd; ASSERT_THAT(clientfd = accept(sockets->first_fd(), nullptr, nullptr), SyscallSucceeds()); + auto cleanup = Cleanup( + [clientfd]() { ASSERT_THAT(close(clientfd), SyscallSucceeds()); }); // Find the entry for the accepted socket. UDS proc entries don't have a // remote address, so we distinguish the accepted socket from the listen diff --git a/test/syscalls/linux/setgid.cc b/test/syscalls/linux/setgid.cc new file mode 100644 index 000000000..bfd91ba4f --- /dev/null +++ b/test/syscalls/linux/setgid.cc @@ -0,0 +1,370 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <limits.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kDirmodeMask = 07777; +constexpr int kDirmodeSgid = S_ISGID | 0777; +constexpr int kDirmodeNoExec = S_ISGID | 0767; +constexpr int kDirmodeNoSgid = 0777; + +// Sets effective GID and returns a Cleanup that restores the original. +PosixErrorOr<Cleanup> Setegid(gid_t egid) { + gid_t old_gid = getegid(); + if (setegid(egid) < 0) { + return PosixError(errno, absl::StrFormat("setegid(%d)", egid)); + } + return Cleanup( + [old_gid]() { EXPECT_THAT(setegid(old_gid), SyscallSucceeds()); }); +} + +// Returns a pair of groups that the user is a member of. +PosixErrorOr<std::pair<gid_t, gid_t>> Groups() { + // See whether the user is a member of at least 2 groups. + std::vector<gid_t> groups(64); + for (; groups.size() <= NGROUPS_MAX; groups.resize(groups.size() * 2)) { + int ngroups = getgroups(groups.size(), groups.data()); + if (ngroups < 0 && errno == EINVAL) { + // Need a larger list. + continue; + } + if (ngroups < 0) { + return PosixError(errno, absl::StrFormat("getgroups(%d, %p)", + groups.size(), groups.data())); + } + if (ngroups >= 2) { + return std::pair<gid_t, gid_t>(groups[0], groups[1]); + } + // There aren't enough groups. + break; + } + + // If we're root in the root user namespace, we can set our GID to whatever we + // want. Try that before giving up. + constexpr gid_t kGID1 = 1111; + constexpr gid_t kGID2 = 2222; + auto cleanup1 = Setegid(kGID1); + if (!cleanup1.ok()) { + return cleanup1.error(); + } + auto cleanup2 = Setegid(kGID2); + if (!cleanup2.ok()) { + return cleanup2.error(); + } + return std::pair<gid_t, gid_t>(kGID1, kGID2); +} + +class SetgidDirTest : public ::testing::Test { + protected: + void SetUp() override { + original_gid_ = getegid(); + + // TODO(b/175325250): Enable when setgid directories are supported. + SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETGID))); + + temp_dir_ = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); + groups_ = ASSERT_NO_ERRNO_AND_VALUE(Groups()); + } + + void TearDown() override { + ASSERT_THAT(setegid(original_gid_), SyscallSucceeds()); + } + + void MkdirAsGid(gid_t gid, const std::string& path, mode_t mode) { + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(gid)); + ASSERT_THAT(mkdir(path.c_str(), mode), SyscallSucceeds()); + } + + PosixErrorOr<struct stat> Stat(const std::string& path) { + struct stat stats; + if (stat(path.c_str(), &stats) < 0) { + return PosixError(errno, absl::StrFormat("stat(%s, _)", path)); + } + return stats; + } + + PosixErrorOr<struct stat> Stat(const FileDescriptor& fd) { + struct stat stats; + if (fstat(fd.get(), &stats) < 0) { + return PosixError(errno, "fstat(_, _)"); + } + return stats; + } + + TempPath temp_dir_; + std::pair<gid_t, gid_t> groups_; + gid_t original_gid_; +}; + +// The control test. Files created with a given GID are owned by that group. +TEST_F(SetgidDirTest, Control) { + // Set group to G1 and create a directory. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, 0777)); + + // Set group to G2, create a file in g1owned, and confirm that G2 owns it. + ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "g2owned").c_str(), O_CREAT | O_RDWR, 0777)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); +} + +// Setgid directories cause created files to inherit GID. +TEST_F(SetgidDirTest, CreateFile) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeSgid)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); + + // Set group to G2, create a file, and confirm that G1 owns it. + ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "g2created").c_str(), O_CREAT | O_RDWR, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.first); +} + +// Setgid directories cause created directories to inherit GID. +TEST_F(SetgidDirTest, CreateDir) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeSgid)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); + + // Set group to G2, create a directory, confirm that G1 owns it, and that the + // setgid bit is enabled. + auto g2created = JoinPath(g1owned, "g2created"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.second, g2created, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); + EXPECT_EQ(stats.st_gid, groups_.first); + EXPECT_EQ(stats.st_mode & S_ISGID, S_ISGID); +} + +// Setgid directories with group execution disabled still cause GID inheritance. +TEST_F(SetgidDirTest, NoGroupExec) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeNoExec)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoExec), SyscallSucceeds()); + + // Set group to G2, create a directory, confirm that G2 owns it, and that the + // setgid bit is enabled. + auto g2created = JoinPath(g1owned, "g2created"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.second, g2created, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); + EXPECT_EQ(stats.st_gid, groups_.first); + EXPECT_EQ(stats.st_mode & S_ISGID, S_ISGID); +} + +// Setting the setgid bit on directories with an existing file does not change +// the file's group. +TEST_F(SetgidDirTest, OldFile) { + // Set group to G1 and create a directory. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeNoSgid)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoSgid), SyscallSucceeds()); + + // Set group to G2, create a file, confirm that G2 owns it. + ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "g2created").c_str(), O_CREAT | O_RDWR, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); + + // Enable setgid. + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); + + // Confirm that the file's group is still G2. + stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); +} + +// Setting the setgid bit on directories with an existing subdirectory does not +// change the subdirectory's group. +TEST_F(SetgidDirTest, OldDir) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeNoSgid)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoSgid), SyscallSucceeds()); + + // Set group to G2, create a directory, confirm that G2 owns it. + ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + auto g2created = JoinPath(g1owned, "g2created"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.second, g2created, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); + EXPECT_EQ(stats.st_gid, groups_.second); + + // Enable setgid. + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); + + // Confirm that the file's group is still G2. + stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); + EXPECT_EQ(stats.st_gid, groups_.second); +} + +// Chowning a file clears the setgid and setuid bits. +TEST_F(SetgidDirTest, ChownFileClears) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeMask)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeMask), SyscallSucceeds()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "newfile").c_str(), O_CREAT | O_RDWR, 0666)); + ASSERT_THAT(fchmod(fd.get(), 0777 | S_ISUID | S_ISGID), SyscallSucceeds()); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.first); + EXPECT_EQ(stats.st_mode & (S_ISUID | S_ISGID), S_ISUID | S_ISGID); + + // Change the owning group. + ASSERT_THAT(fchown(fd.get(), -1, groups_.second), SyscallSucceeds()); + + // The setgid and setuid bits should be cleared. + stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); + EXPECT_EQ(stats.st_mode & (S_ISUID | S_ISGID), 0); +} + +// Chowning a file with setgid enabled, but not the group exec bit, does not +// clear the setgid bit. Such files are mandatory locked. +TEST_F(SetgidDirTest, ChownNoExecFileDoesNotClear) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeNoExec)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoExec), SyscallSucceeds()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "newdir").c_str(), O_CREAT | O_RDWR, 0666)); + ASSERT_THAT(fchmod(fd.get(), 0766 | S_ISUID | S_ISGID), SyscallSucceeds()); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.first); + EXPECT_EQ(stats.st_mode & (S_ISUID | S_ISGID), S_ISUID | S_ISGID); + + // Change the owning group. + ASSERT_THAT(fchown(fd.get(), -1, groups_.second), SyscallSucceeds()); + + // Only the setuid bit is cleared. + stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); + EXPECT_EQ(stats.st_mode & (S_ISUID | S_ISGID), S_ISGID); +} + +// Chowning a directory with setgid enabled does not clear the bit. +TEST_F(SetgidDirTest, ChownDirDoesNotClear) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeMask)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeMask), SyscallSucceeds()); + + // Change the owning group. + ASSERT_THAT(chown(g1owned.c_str(), -1, groups_.second), SyscallSucceeds()); + + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g1owned)); + EXPECT_EQ(stats.st_gid, groups_.second); + EXPECT_EQ(stats.st_mode & kDirmodeMask, kDirmodeMask); +} + +struct FileModeTestcase { + std::string name; + mode_t mode; + mode_t result_mode; + + FileModeTestcase(const std::string& name, mode_t mode, mode_t result_mode) + : name(name), mode(mode), result_mode(result_mode) {} +}; + +class FileModeTest : public ::testing::TestWithParam<FileModeTestcase> {}; + +TEST_P(FileModeTest, WriteToFile) { + // TODO(b/175325250): Enable when setgid directories are supported. + SKIP_IF(IsRunningOnGvisor()); + + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); + auto path = JoinPath(temp_dir.path(), GetParam().name); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_CREAT | O_RDWR, 0666)); + ASSERT_THAT(fchmod(fd.get(), GetParam().mode), SyscallSucceeds()); + struct stat stats; + ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); + EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().mode); + + // For security reasons, writing to the file clears the SUID bit, and clears + // the SGID bit when the group executable bit is unset (which is not a true + // SGID binary). + constexpr char kInput = 'M'; + ASSERT_THAT(write(fd.get(), &kInput, sizeof(kInput)), + SyscallSucceedsWithValue(sizeof(kInput))); + + ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); + EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().result_mode); +} + +TEST_P(FileModeTest, TruncateFile) { + // TODO(b/175325250): Enable when setgid directories are supported. + SKIP_IF(IsRunningOnGvisor()); + + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); + auto path = JoinPath(temp_dir.path(), GetParam().name); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_CREAT | O_RDWR, 0666)); + ASSERT_THAT(fchmod(fd.get(), GetParam().mode), SyscallSucceeds()); + struct stat stats; + ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); + EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().mode); + + // For security reasons, truncating the file clears the SUID bit, and clears + // the SGID bit when the group executable bit is unset (which is not a true + // SGID binary). + ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds()); + + ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); + EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().result_mode); +} + +INSTANTIATE_TEST_SUITE_P( + FileModes, FileModeTest, + ::testing::ValuesIn<FileModeTestcase>( + {FileModeTestcase("normal file", 0777, 0777), + FileModeTestcase("setuid", S_ISUID | 0777, 00777), + FileModeTestcase("setgid", S_ISGID | 0777, 00777), + FileModeTestcase("setuid and setgid", S_ISUID | S_ISGID | 0777, 00777), + FileModeTestcase("setgid without exec", S_ISGID | 0767, + S_ISGID | 0767), + FileModeTestcase("setuid and setgid without exec", + S_ISGID | S_ISUID | 0767, S_ISGID | 0767)})); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index a73987a7e..2c8e5f6f3 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -65,6 +65,9 @@ TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) { SyscallSucceeds()); } +// Copied from include/net/tcp.h. +constexpr int TCP_CA_OPEN = 0; + TEST_P(TCPSocketPairTest, CheckTcpInfoFields) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -87,7 +90,7 @@ TEST_P(TCPSocketPairTest, CheckTcpInfoFields) { SyscallSucceeds()); // Validates the received tcp_info fields. - EXPECT_EQ(opt.tcpi_ca_state, 0); + EXPECT_EQ(opt.tcpi_ca_state, TCP_CA_OPEN); EXPECT_GT(opt.tcpi_snd_cwnd, 0); EXPECT_GT(opt.tcpi_rto, 0); } diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index 64d6d0b8f..4139a18d8 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -23,6 +23,8 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "test/util/capability_util.h" +#include "test/util/cleanup.h" +#include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -33,6 +35,16 @@ ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); ABSL_FLAG(int32_t, scratch_gid1, 65534, "first scratch GID"); ABSL_FLAG(int32_t, scratch_gid2, 65533, "second scratch GID"); +// Force use of syscall instead of glibc set*id() wrappers because we want to +// apply to the current task only. libc sets all threads in a process because +// "POSIX requires that all threads in a process share the same credentials." +#define setuid USE_SYSCALL_INSTEAD +#define setgid USE_SYSCALL_INSTEAD +#define setreuid USE_SYSCALL_INSTEAD +#define setregid USE_SYSCALL_INSTEAD +#define setresuid USE_SYSCALL_INSTEAD +#define setresgid USE_SYSCALL_INSTEAD + using ::testing::UnorderedElementsAreArray; namespace gvisor { @@ -137,21 +149,31 @@ TEST(UidGidRootTest, Setuid) { TEST(UidGidRootTest, Setgid) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - EXPECT_THAT(setgid(-1), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(syscall(SYS_setgid, -1), SyscallFailsWithErrno(EINVAL)); - const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); - ASSERT_THAT(setgid(gid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); + ScopedThread([&] { + const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); + EXPECT_THAT(syscall(SYS_setgid, gid), SyscallSucceeds()); + EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); + }); } TEST(UidGidRootTest, SetgidNotFromThreadGroupLeader) { +#pragma push_macro("allow_setgid") +#undef setgid + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); + int old_gid = getgid(); + auto clean = Cleanup([old_gid] { setgid(old_gid); }); + const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); // NOTE(b/64676707): Do setgid in a separate thread so that we can test if // info.si_pid is set correctly. ScopedThread([gid] { ASSERT_THAT(setgid(gid), SyscallSucceeds()); }); EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); + +#pragma pop_macro("allow_setgid") } TEST(UidGidRootTest, Setreuid) { @@ -159,27 +181,25 @@ TEST(UidGidRootTest, Setreuid) { // "Supplying a value of -1 for either the real or effective user ID forces // the system to leave that ID unchanged." - setreuid(2) - EXPECT_THAT(setreuid(-1, -1), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setreuid, -1, -1), SyscallSucceeds()); + EXPECT_NO_ERRNO(CheckUIDs(0, 0, 0)); // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. After calling - // setuid(non-zero-UID), there is no way to get root privileges back. + // process can still open files the test harness created before starting + // this test. Otherwise, the files are created by root (UID before the + // test), but cannot be opened by the `uid` set below after the test. After + // calling setuid(non-zero-UID), there is no way to get root privileges + // back. ScopedThread([&] { const uid_t ruid = absl::GetFlag(FLAGS_scratch_uid1); const uid_t euid = absl::GetFlag(FLAGS_scratch_uid2); - // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. posix threads, however, require that all - // threads have the same UIDs, so using the setuid wrapper sets all threads' - // real UID. EXPECT_THAT(syscall(SYS_setreuid, ruid, euid), SyscallSucceeds()); // "If the real user ID is set or the effective user ID is set to a value - // not equal to the previous real user ID, the saved set-user-ID will be set - // to the new effective user ID." - setreuid(2) + // not equal to the previous real user ID, the saved set-user-ID will be + // set to the new effective user ID." - setreuid(2) EXPECT_NO_ERRNO(CheckUIDs(ruid, euid, euid)); }); } @@ -187,13 +207,15 @@ TEST(UidGidRootTest, Setreuid) { TEST(UidGidRootTest, Setregid) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - EXPECT_THAT(setregid(-1, -1), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setregid, -1, -1), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0)); - const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1); - const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2); - ASSERT_THAT(setregid(rgid, egid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid)); + ScopedThread([&] { + const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1); + const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2); + ASSERT_THAT(syscall(SYS_setregid, rgid, egid), SyscallSucceeds()); + EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid)); + }); } TEST(UidGidRootTest, Setresuid) { @@ -201,23 +223,24 @@ TEST(UidGidRootTest, Setresuid) { // "If one of the arguments equals -1, the corresponding value is not // changed." - setresuid(2) - EXPECT_THAT(setresuid(-1, -1, -1), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setresuid, -1, -1, -1), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckUIDs(0, 0, 0)); // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. After calling - // setuid(non-zero-UID), there is no way to get root privileges back. + // process can still open files the test harness created before starting + // this test. Otherwise, the files are created by root (UID before the + // test), but cannot be opened by the `uid` set below after the test. After + // calling setuid(non-zero-UID), there is no way to get root privileges + // back. ScopedThread([&] { const uid_t ruid = 12345; const uid_t euid = 23456; const uid_t suid = 34567; // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. posix threads, however, require that all - // threads have the same UIDs, so using the setuid wrapper sets all threads' - // real UID. + // call to only apply to this task. posix threads, however, require that + // all threads have the same UIDs, so using the setuid wrapper sets all + // threads' real UID. EXPECT_THAT(syscall(SYS_setresuid, ruid, euid, suid), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckUIDs(ruid, euid, suid)); }); @@ -226,14 +249,16 @@ TEST(UidGidRootTest, Setresuid) { TEST(UidGidRootTest, Setresgid) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - EXPECT_THAT(setresgid(-1, -1, -1), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setresgid, -1, -1, -1), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0)); - const gid_t rgid = 12345; - const gid_t egid = 23456; - const gid_t sgid = 34567; - ASSERT_THAT(setresgid(rgid, egid, sgid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, sgid)); + ScopedThread([&] { + const gid_t rgid = 12345; + const gid_t egid = 23456; + const gid_t sgid = 34567; + ASSERT_THAT(syscall(SYS_setresgid, rgid, egid, sgid), SyscallSucceeds()); + EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, sgid)); + }); } TEST(UidGidRootTest, Setgroups) { @@ -254,14 +279,14 @@ TEST(UidGidRootTest, Setuid_prlimit) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); // Do seteuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. + // process can still open files the test harness created before starting + // this test. Otherwise, the files are created by root (UID before the + // test), but cannot be opened by the `uid` set below after the test. ScopedThread([&] { - // Use syscall instead of glibc setuid wrapper because we want this seteuid - // call to only apply to this task. POSIX threads, however, require that all - // threads have the same UIDs, so using the seteuid wrapper sets all - // threads' UID. + // Use syscall instead of glibc setuid wrapper because we want this + // seteuid call to only apply to this task. POSIX threads, however, + // require that all threads have the same UIDs, so using the seteuid + // wrapper sets all threads' UID. EXPECT_THAT(syscall(SYS_setreuid, -1, 65534), SyscallSucceeds()); // Despite the UID change, we should be able to get our own limits. diff --git a/test/util/capability_util.h b/test/util/capability_util.h index bb9ea1fe5..a03bc7e05 100644 --- a/test/util/capability_util.h +++ b/test/util/capability_util.h @@ -96,6 +96,19 @@ inline PosixError DropPermittedCapability(int cap) { PosixErrorOr<bool> CanCreateUserNamespace(); +class AutoCapability { + public: + AutoCapability(int cap, bool set) : cap_(cap), set_(set) { + EXPECT_NO_ERRNO(SetCapability(cap_, set_)); + } + + ~AutoCapability() { EXPECT_NO_ERRNO(SetCapability(cap_, !set_)); } + + private: + int cap_; + bool set_; +}; + } // namespace testing } // namespace gvisor #endif // GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_ diff --git a/test/util/logging.h b/test/util/logging.h index 9d224ea05..5c17f1233 100644 --- a/test/util/logging.h +++ b/test/util/logging.h @@ -96,6 +96,21 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, std::move(_expr_result).ValueOrDie(); \ }) +// cond must be greater or equal than 0. Used to test result of syscalls. +// +// This macro is async-signal-safe. +#define TEST_CHECK_SUCCESS(cond) TEST_PCHECK((cond) >= 0) + +// cond must be -1 and errno must match errno_value. Used to test errors from +// syscalls. +// +// This macro is async-signal-safe. +#define TEST_CHECK_ERRNO(cond, errno_value) \ + do { \ + TEST_PCHECK((cond) == -1); \ + TEST_PCHECK_MSG(errno == (errno_value), #cond " expected " #errno_value); \ + } while (0) + } // namespace testing } // namespace gvisor |