diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 23 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarder.go | 131 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarder_test.go | 638 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 342 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 155 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 180 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp.go | 463 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 900 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 253 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 78 | ||||
-rw-r--r-- | pkg/tcpip/stack/rand.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 60 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 590 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 438 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 243 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 34 |
19 files changed, 3635 insertions, 994 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 8febd54c8..5e963a4af 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -15,15 +15,34 @@ go_template_instance( }, ) +go_template_instance( + name = "packet_buffer_list", + out = "packet_buffer_list.go", + package = "stack", + prefix = "PacketBuffer", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*PacketBuffer", + "Linker": "*PacketBuffer", + }, +) + go_library( name = "stack", srcs = [ "dhcpv6configurationfromndpra_string.go", + "forwarder.go", "icmp_rate_limit.go", + "iptables.go", + "iptables_targets.go", + "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", "ndp.go", "nic.go", + "packet_buffer.go", + "packet_buffer_list.go", + "rand.go", "registration.go", "route.go", "stack.go", @@ -33,6 +52,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/ilist", + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", @@ -40,7 +60,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", @@ -64,7 +83,6 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", @@ -80,6 +98,7 @@ go_test( name = "stack_test", size = "small", srcs = [ + "forwarder_test.go", "linkaddrcache_test.go", "nic_test.go", ], diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go new file mode 100644 index 000000000..6b64cd37f --- /dev/null +++ b/pkg/tcpip/stack/forwarder.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // maxPendingResolutions is the maximum number of pending link-address + // resolutions. + maxPendingResolutions = 64 + maxPendingPacketsPerResolution = 256 +) + +type pendingPacket struct { + nic *NIC + route *Route + proto tcpip.NetworkProtocolNumber + pkt PacketBuffer +} + +type forwardQueue struct { + sync.Mutex + + // The packets to send once the resolver completes. + packets map[<-chan struct{}][]*pendingPacket + + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + cancelChans []chan struct{} +} + +func newForwardQueue() *forwardQueue { + return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} +} + +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { + shouldWait := false + + f.Lock() + packets, ok := f.packets[ch] + if !ok { + shouldWait = true + } + for len(packets) == maxPendingPacketsPerResolution { + p := packets[0] + packets = packets[1:] + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Release() + } + if l := len(packets); l >= maxPendingPacketsPerResolution { + panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) + } + f.packets[ch] = append(packets, &pendingPacket{ + nic: n, + route: r, + proto: protocol, + pkt: pkt, + }) + f.Unlock() + + if !shouldWait { + return + } + + // Wait for the link-address resolution to complete. + // Start a goroutine with a forwarding-cancel channel so that we can + // limit the maximum number of goroutines running concurrently. + cancel := f.newCancelChannel() + go func() { + cancelled := false + select { + case <-ch: + case <-cancel: + cancelled = true + } + + f.Lock() + packets := f.packets[ch] + delete(f.packets, ch) + f.Unlock() + + for _, p := range packets { + if cancelled { + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + } else if _, err := p.route.Resolve(nil); err != nil { + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + } else { + p.nic.forwardPacket(p.route, p.proto, p.pkt) + } + p.route.Release() + } + }() +} + +// newCancelChannel creates a channel that can cancel a pending forwarding +// activity. The oldest channel is closed if the number of open channels would +// exceed maxPendingResolutions. +func (f *forwardQueue) newCancelChannel() chan struct{} { + f.Lock() + defer f.Unlock() + + if len(f.cancelChans) == maxPendingResolutions { + ch := f.cancelChans[0] + f.cancelChans = f.cancelChans[1:] + close(ch) + } + if l := len(f.cancelChans); l >= maxPendingResolutions { + panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) + } + + ch := make(chan struct{}) + f.cancelChans = append(f.cancelChans, ch) + return ch +} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go new file mode 100644 index 000000000..c7c663498 --- /dev/null +++ b/pkg/tcpip/stack/forwarder_test.go @@ -0,0 +1,638 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +const ( + fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fwdTestNetHeaderLen = 12 + fwdTestNetDefaultPrefixLen = 8 + + // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match + // the MTU of loopback interfaces on linux systems. + fwdTestNetDefaultMTU = 65536 +) + +// fwdTestNetworkEndpoint is a network-layer protocol endpoint. +// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only +// use the first three: destination address, source address, and transport +// protocol. They're all one byte fields to simplify parsing. +type fwdTestNetworkEndpoint struct { + nicID tcpip.NICID + id NetworkEndpointID + prefixLen int + proto *fwdTestNetworkProtocol + dispatcher TransportDispatcher + ep LinkEndpoint +} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) +} + +func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { + return f.nicID +} + +func (f *fwdTestNetworkEndpoint) PrefixLen() int { + return f.prefixLen +} + +func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { + return 123 +} + +func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { + return &f.id +} + +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { + // Consume the network header. + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } + pkt.Data.TrimFront(fwdTestNetHeaderLen) + + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) +} + +func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { + return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { + return 0 +} + +func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { + return f.ep.Capabilities() +} + +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { + // Add the protocol's header to the packet and send it to the link + // endpoint. + b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b[0] = r.RemoteAddress[0] + b[1] = f.id.LocalAddress[0] + b[2] = byte(params.Protocol) + + return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { + panic("not implemented") +} + +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func (*fwdTestNetworkEndpoint) Close() {} + +// fwdTestNetworkProtocol is a network-layer protocol that implements Address +// resolution. +type fwdTestNetworkProtocol struct { + addrCache *linkAddrCache + addrResolveDelay time.Duration + onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) + onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) +} + +func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber +} + +func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { + return fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { + return fwdTestNetDefaultPrefixLen +} + +func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) +} + +func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { + return &fwdTestNetworkEndpoint{ + nicID: nicID, + id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + proto: f, + dispatcher: dispatcher, + ep: ep, + }, nil +} + +func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (f *fwdTestNetworkProtocol) Close() {} + +func (f *fwdTestNetworkProtocol) Wait() {} + +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { + if f.addrCache != nil && f.onLinkAddressResolved != nil { + time.AfterFunc(f.addrResolveDelay, func() { + f.onLinkAddressResolved(f.addrCache, addr) + }) + } + return nil +} + +func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if f.onResolveStaticAddress != nil { + return f.onResolveStaticAddress(addr) + } + return "", false +} + +func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber +} + +// fwdTestPacketInfo holds all the information about an outbound packet. +type fwdTestPacketInfo struct { + RemoteLinkAddress tcpip.LinkAddress + LocalLinkAddress tcpip.LinkAddress + Pkt PacketBuffer +} + +type fwdTestLinkEndpoint struct { + dispatcher NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + + // C is where outbound packets are queued. + C chan fwdTestPacketInfo +} + +// InjectInbound injects an inbound packet. +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) +} + +// InjectLinkAddr injects an inbound packet with a remote link address. +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *fwdTestLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *fwdTestLinkEndpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { + caps := LinkEndpointCapabilities(0) + return caps | CapabilityResolutionRequired +} + +// GSOMaxSize returns the maximum GSO packet size. +func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { + return 1 << 15 +} + +// MaxHeaderLength returns the maximum size of the link layer header. Given it +// doesn't have a header, it just returns 0. +func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { + p := fwdTestPacketInfo{ + RemoteLinkAddress: r.RemoteLinkAddress, + LocalLinkAddress: r.LocalLinkAddress, + Pkt: pkt, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// WritePackets stores outbound packets into the channel. +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.WritePacket(r, gso, protocol, *pkt) + n++ + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + p := fwdTestPacketInfo{ + Pkt: PacketBuffer{Data: vv}, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*fwdTestLinkEndpoint) Wait() {} + +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { + // Create a stack with the network protocol and two NICs. + s := New(Options{ + NetworkProtocols: []NetworkProtocol{proto}, + }) + + proto.addrCache = s.linkAddrCache + + // Enable forwarding. + s.SetForwarding(true) + + // NIC 1 has the link address "a", and added the network address 1. + ep1 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "a", + } + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC #1 failed:", err) + } + if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { + t.Fatal("AddAddress #1 failed:", err) + } + + // NIC 2 has the link address "b", and added the network address 2. + ep2 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "b", + } + if err := s.CreateNIC(2, ep2); err != nil { + t.Fatal("CreateNIC #2 failed:", err) + } + if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { + t.Fatal("AddAddress #2 failed:", err) + } + + // Route all packets to NIC 2. + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) + } + + return ep1, ep2 +} + +func TestForwardingWithStaticResolver(t *testing.T) { + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } + + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithFakeResolver(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any address will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithNoResolver(t *testing.T) { + // Create a network protocol without a resolver. + proto := &fwdTestNetworkProtocol{} + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + select { + case <-ep2.C: + t.Fatal("Packet should not be forwarded") + case <-time.After(time.Second): + } +} + +func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + } + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[0] = 4 + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Data.ToView() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Data.ToView() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} + +func TestForwardingWithFakeResolverManyPackets(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[0] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Data.ToView() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + // The first 5 packets should not be forwarded so the the + // sequemnce number should start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} + +func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[0] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + b := p.Pkt.Data.ToView() + if b[0] < 8 { + t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go new file mode 100644 index 000000000..6b91159d4 --- /dev/null +++ b/pkg/tcpip/stack/iptables.go @@ -0,0 +1,342 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// Table names. +const ( + TablenameNat = "nat" + TablenameMangle = "mangle" + TablenameFilter = "filter" +) + +// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +const ( + ChainNamePrerouting = "PREROUTING" + ChainNameInput = "INPUT" + ChainNameForward = "FORWARD" + ChainNameOutput = "OUTPUT" + ChainNamePostrouting = "POSTROUTING" +) + +// HookUnset indicates that there is no hook set for an entrypoint or +// underflow. +const HookUnset = -1 + +// DefaultTables returns a default set of tables. Each chain is set to accept +// all packets. +func DefaultTables() IPTables { + // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for + // iotas. + return IPTables{ + Tables: map[string]Table{ + TablenameNat: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, + }, + Underflows: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, + }, + UserChains: map[string]int{}, + }, + TablenameMangle: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: map[Hook]int{ + Prerouting: 0, + Output: 1, + }, + UserChains: map[string]int{}, + }, + TablenameFilter: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + Underflows: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + UserChains: map[string]int{}, + }, + }, + Priorities: map[Hook][]string{ + Input: []string{TablenameNat, TablenameFilter}, + Prerouting: []string{TablenameMangle, TablenameNat}, + Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + }, + } +} + +// EmptyFilterTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyFilterTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + Underflows: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + UserChains: map[string]int{}, + } +} + +// EmptyNatTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyNatTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + Underflows: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + UserChains: map[string]int{}, + } +} + +// A chainVerdict is what a table decides should be done with a packet. +type chainVerdict int + +const ( + // chainAccept indicates the packet should continue through netstack. + chainAccept chainVerdict = iota + + // chainAccept indicates the packet should be dropped. + chainDrop + + // chainReturn indicates the packet should return to the calling chain + // or the underflow rule of a builtin chain. + chainReturn +) + +// Check runs pkt through the rules for hook. It returns true when the packet +// should continue traversing the network stack and false when it should be +// dropped. +// +// Precondition: pkt.NetworkHeader is set. +func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { + // Go through each table containing the hook. + for _, tablename := range it.Priorities[hook] { + table := it.Tables[tablename] + ruleIdx := table.BuiltinChains[hook] + switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { + // If the table returns Accept, move on to the next table. + case chainAccept: + continue + // The Drop verdict is final. + case chainDrop: + return false + case chainReturn: + // Any Return from a built-in chain means we have to + // call the underflow. + underflow := table.Rules[table.Underflows[hook]] + switch v, _ := underflow.Target.Action(pkt); v { + case RuleAccept: + continue + case RuleDrop: + return false + case RuleJump, RuleReturn: + panic("Underflows should only return RuleAccept or RuleDrop.") + default: + panic(fmt.Sprintf("Unknown verdict: %d", v)) + } + + default: + panic(fmt.Sprintf("Unknown verdict %v.", verdict)) + } + } + + // Every table returned Accept. + return true +} + +// CheckPackets runs pkts through the rules for hook and returns a map of packets that +// should not go forward. +// +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// +// NOTE: unlike the Check API the returned map contains packets that should be +// dropped. +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if ok := it.Check(hook, *pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + } + return drop +} + +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { + // Start from ruleIdx and walk the list of rules until a rule gives us + // a verdict. + for ruleIdx < len(table.Rules) { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { + case RuleAccept: + return chainAccept + + case RuleDrop: + return chainDrop + + case RuleReturn: + return chainReturn + + case RuleJump: + // "Jumping" to the next rule just means we're + // continuing on down the list. + if jumpTo == ruleIdx+1 { + ruleIdx++ + continue + } + switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + case chainAccept: + return chainAccept + case chainDrop: + return chainDrop + case chainReturn: + ruleIdx++ + continue + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + } + + // We got through the entire table without a decision. Default to DROP + // for safety. + return chainDrop +} + +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { + rule := table.Rules[ruleIdx] + + // If pkt.NetworkHeader hasn't been set yet, it will be contained in + // pkt.Data. + if pkt.NetworkHeader == nil { + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } + } + + // Check whether the packet matches the IP header filter. + if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + + // Go through each rule matcher. If they all match, run + // the rule target. + for _, matcher := range rule.Matchers { + matches, hotdrop := matcher.Match(hook, pkt, "") + if hotdrop { + return RuleDrop, 0 + } + if !matches { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + } + + // All the matchers matched, so run the target. + return rule.Target.Action(pkt) +} + +func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the destination IP. + dest := hdr.DestinationAddress() + matches := true + for i := range filter.Dst { + if dest[i]&filter.DstMask[i] != filter.Dst[i] { + matches = false + break + } + } + if matches == filter.DstInvert { + return false + } + + return true +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go new file mode 100644 index 000000000..8be61f4b1 --- /dev/null +++ b/pkg/tcpip/stack/iptables_targets.go @@ -0,0 +1,155 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// AcceptTarget accepts packets. +type AcceptTarget struct{} + +// Action implements Target.Action. +func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + return RuleAccept, 0 +} + +// DropTarget drops packets. +type DropTarget struct{} + +// Action implements Target.Action. +func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + return RuleDrop, 0 +} + +// ErrorTarget logs an error and drops the packet. It represents a target that +// should be unreachable. +type ErrorTarget struct{} + +// Action implements Target.Action. +func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + log.Debugf("ErrorTarget triggered.") + return RuleDrop, 0 +} + +// UserChainTarget marks a rule as the beginning of a user chain. +type UserChainTarget struct { + Name string +} + +// Action implements Target.Action. +func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { + panic("UserChainTarget should never be called.") +} + +// ReturnTarget returns from the current chain. If the chain is a built-in, the +// hook's underflow should be called. +type ReturnTarget struct{} + +// Action implements Target.Action. +func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) { + return RuleReturn, 0 +} + +// RedirectTarget redirects the packet by modifying the destination port/IP. +// Min and Max values for IP and Ports in the struct indicate the range of +// values which can be used to redirect. +type RedirectTarget struct { + // TODO(gvisor.dev/issue/170): Other flags need to be added after + // we support them. + // RangeProtoSpecified flag indicates single port is specified to + // redirect. + RangeProtoSpecified bool + + // Min address used to redirect. + MinIP tcpip.Address + + // Max address used to redirect. + MaxIP tcpip.Address + + // Min port used to redirect. + MinPort uint16 + + // Max port used to redirect. + MaxPort uint16 +} + +// Action implements Target.Action. +// TODO(gvisor.dev/issue/170): Parse headers without copying. The current +// implementation only works for PREROUTING and calls pkt.Clone(), neither +// of which should be the case. +func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { + newPkt := pkt.Clone() + + // Set network header. + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } + netHeader := header.IPv4(headerView) + newPkt.NetworkHeader = headerView + + hlen := int(netHeader.HeaderLength()) + tlen := int(netHeader.TotalLength()) + newPkt.Data.TrimFront(hlen) + newPkt.Data.CapLength(tlen - hlen) + + // TODO(gvisor.dev/issue/170): Change destination address to + // loopback or interface address on which the packet was + // received. + + // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if + // we need to change dest address (for OUTPUT chain) or ports. + switch protocol := netHeader.TransportProtocol(); protocol { + case header.UDPProtocolNumber: + var udpHeader header.UDP + if newPkt.TransportHeader != nil { + udpHeader = header.UDP(newPkt.TransportHeader) + } else { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + return RuleDrop, 0 + } + udpHeader = header.UDP(hdr) + } + udpHeader.SetDestinationPort(rt.MinPort) + case header.TCPProtocolNumber: + var tcpHeader header.TCP + if newPkt.TransportHeader != nil { + tcpHeader = header.TCP(newPkt.TransportHeader) + } else { + if pkt.Data.Size() < header.TCPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) + } + // TODO(gvisor.dev/issue/170): Need to recompute checksum + // and implement nat connection tracking to support TCP. + tcpHeader.SetDestinationPort(rt.MinPort) + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go new file mode 100644 index 000000000..2ffb55f2a --- /dev/null +++ b/pkg/tcpip/stack/iptables_types.go @@ -0,0 +1,180 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +// A Hook specifies one of the hooks built into the network stack. +// +// Userspace app Userspace app +// ^ | +// | v +// [Input] [Output] +// ^ | +// | v +// | routing +// | | +// | v +// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> +type Hook uint + +// These values correspond to values in include/uapi/linux/netfilter.h. +const ( + // Prerouting happens before a packet is routed to applications or to + // be forwarded. + Prerouting Hook = iota + + // Input happens before a packet reaches an application. + Input + + // Forward happens once it's decided that a packet should be forwarded + // to another host. + Forward + + // Output happens after a packet is written by an application to be + // sent out. + Output + + // Postrouting happens just before a packet goes out on the wire. + Postrouting + + // The total number of hooks. + NumHooks +) + +// A RuleVerdict is what a rule decides should be done with a packet. +type RuleVerdict int + +const ( + // RuleAccept indicates the packet should continue through netstack. + RuleAccept RuleVerdict = iota + + // RuleDrop indicates the packet should be dropped. + RuleDrop + + // RuleJump indicates the packet should jump to another chain. + RuleJump + + // RuleReturn indicates the packet should return to the previous chain. + RuleReturn +) + +// IPTables holds all the tables for a netstack. +type IPTables struct { + // Tables maps table names to tables. User tables have arbitrary names. + Tables map[string]Table + + // Priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. + Priorities map[Hook][]string +} + +// A Table defines a set of chains and hooks into the network stack. It is +// really just a list of rules with some metadata for entrypoints and such. +type Table struct { + // Rules holds the rules that make up the table. + Rules []Rule + + // BuiltinChains maps builtin chains to their entrypoint rule in Rules. + BuiltinChains map[Hook]int + + // Underflows maps builtin chains to their underflow rule in Rules + // (i.e. the rule to execute if the chain returns without a verdict). + Underflows map[Hook]int + + // UserChains holds user-defined chains for the keyed by name. Users + // can give their chains arbitrary names. + UserChains map[string]int + + // Metadata holds information about the Table that is useful to users + // of IPTables, but not to the netstack IPTables code itself. + metadata interface{} +} + +// ValidHooks returns a bitmap of the builtin hooks for the given table. +func (table *Table) ValidHooks() uint32 { + hooks := uint32(0) + for hook := range table.BuiltinChains { + hooks |= 1 << hook + } + return hooks +} + +// Metadata returns the metadata object stored in table. +func (table *Table) Metadata() interface{} { + return table.metadata +} + +// SetMetadata sets the metadata object stored in table. +func (table *Table) SetMetadata(metadata interface{}) { + table.metadata = metadata +} + +// A Rule is a packet processing rule. It consists of two pieces. First it +// contains zero or more matchers, each of which is a specification of which +// packets this rule applies to. If there are no matchers in the rule, it +// applies to any packet. +type Rule struct { + // Filter holds basic IP filtering fields common to every rule. + Filter IPHeaderFilter + + // Matchers is the list of matchers for this rule. + Matchers []Matcher + + // Target is the action to invoke if all the matchers match the packet. + Target Target +} + +// IPHeaderFilter holds basic IP filtering data common to every rule. +type IPHeaderFilter struct { + // Protocol matches the transport protocol. + Protocol tcpip.TransportProtocolNumber + + // Dst matches the destination IP address. + Dst tcpip.Address + + // DstMask masks bits of the destination IP address when comparing with + // Dst. + DstMask tcpip.Address + + // DstInvert inverts the meaning of the destination IP check, i.e. when + // true the filter will match packets that fail the destination + // comparison. + DstInvert bool +} + +// A Matcher is the interface for matching packets. +type Matcher interface { + // Name returns the name of the Matcher. + Name() string + + // Match returns whether the packet matches and whether the packet + // should be "hotdropped", i.e. dropped immediately. This is usually + // used for suspicious packets. + // + // Precondition: packet.NetworkHeader is set. + Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) +} + +// A Target is the interface for taking an action for a packet. +type Target interface { + // Action takes an action on the packet and returns a verdict on how + // traversal should (or should not) continue. If the return value is + // Jump, it also returns the index of the rule to jump to. + Action(packet PacketBuffer) (RuleVerdict, int) +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index a9f4d5dad..c11d62f97 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "log" "math/rand" "time" @@ -240,6 +241,16 @@ type NDPDispatcher interface { // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + // OnDNSSearchListOption will be called when an NDP option with a DNS + // search list has been received. + // + // It is up to the caller to use the domain names in the search list + // for only their valid lifetime. OnDNSSearchListOption may be called + // with new or already known domain names. If called with known domain + // names, their valid lifetimes must be refreshed to lifetime (it may + // be increased, decreased or completely invalidated when lifetime = 0. + OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) + // OnDHCPv6Configuration will be called with an updated configuration that is // available via DHCPv6 for a specified NIC. // @@ -304,6 +315,15 @@ type NDPConfigurations struct { // lifetime(s) of the generated address changes; this option only // affects the generation of new addresses as part of SLAAC. AutoGenGlobalAddresses bool + + // AutoGenAddressConflictRetries determines how many times to attempt to retry + // generation of a permanent auto-generated address in response to DAD + // conflicts. + // + // If the method used to generate the address does not support creating + // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's + // MAC address), then no attempt will be made to resolve the conflict. + AutoGenAddressConflictRetries uint8 } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -361,16 +381,16 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState + // The timer used to send the next router solicitation message. + rtrSolicitTimer *time.Timer + // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState - // The timer used to send the next router solicitation message. - // If routers are being solicited, rtrSolicitTimer MUST NOT be nil. - rtrSolicitTimer *time.Timer - - // The addresses generated by SLAAC. - autoGenAddresses map[tcpip.Address]autoGenAddressState + // The SLAAC prefixes discovered through Router Advertisements' Prefix + // Information option. + slaacPrefixes map[tcpip.Subnet]slaacPrefixState // The last learned DHCPv6 configuration from an NDP RA. dhcpv6Configuration DHCPv6ConfigurationFromNDPRA @@ -392,28 +412,54 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - invalidationTimer tcpip.CancellableTimer + // Timer to invalidate the default router. + // + // May not be nil. + invalidationTimer *tcpip.CancellableTimer } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - invalidationTimer tcpip.CancellableTimer + // Timer to invalidate the on-link prefix. + // + // May not be nil. + invalidationTimer *tcpip.CancellableTimer } -// autoGenAddressState holds data associated with an address generated via -// SLAAC. -type autoGenAddressState struct { - // A reference to the referencedNetworkEndpoint that this autoGenAddressState - // is holding state for. - ref *referencedNetworkEndpoint +// slaacPrefixState holds state associated with a SLAAC prefix. +type slaacPrefixState struct { + // Timer to deprecate the prefix. + // + // May not be nil. + deprecationTimer *tcpip.CancellableTimer - deprecationTimer tcpip.CancellableTimer - invalidationTimer tcpip.CancellableTimer + // Timer to invalidate the prefix. + // + // May not be nil. + invalidationTimer *tcpip.CancellableTimer // Nonzero only when the address is not valid forever. validUntil time.Time + + // Nonzero only when the address is not preferred forever. + preferredUntil time.Time + + // The prefix's permanent address endpoint. + // + // May only be nil when a SLAAC address is being (re-)generated. Otherwise, + // must not be nil as all SLAAC prefixes must have a SLAAC address. + ref *referencedNetworkEndpoint + + // The number of times a permanent address has been generated for the prefix. + // + // Addresses may be regenerated in reseponse to a DAD conflicts. + generationAttempts uint8 + + // The maximum number of times to attempt regeneration of a permanent SLAAC + // address in response to DAD conflicts. + maxGenerationAttempts uint8 } // startDuplicateAddressDetection performs Duplicate Address Detection. @@ -430,7 +476,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref if ref.getKind() != permanentTentative { // The endpoint should be marked as tentative since we are starting DAD. - log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) } // Should not attempt to perform DAD on an address that is currently in the @@ -442,7 +488,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. // See NIC.addAddressLocked. - log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()) + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) } remaining := ndp.configs.DupAddrDetectTransmits @@ -478,7 +524,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref if ref.getKind() != permanentTentative { // The endpoint should still be marked as tentative since we are still // performing DAD on it. - log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()) + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())) } dadDone := remaining == 0 @@ -548,9 +594,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { - log.Fatalf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) } else if c != nil { - log.Fatalf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) } hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) @@ -566,7 +612,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, tcpip.PacketBuffer{Header: hdr}, + }, PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -688,7 +734,16 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { continue } - ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) + addrs, _ := opt.Addresses() + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) + + case header.NDPDNSSearchList: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + domainNames, _ := opt.DomainNames() + ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -733,7 +788,6 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { } rtr.invalidationTimer.StopLocked() - delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. @@ -762,7 +816,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } state := defaultRouterState{ - invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), } @@ -792,7 +846,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } state := onLinkPrefixState{ - invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } @@ -817,7 +871,6 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { } s.invalidationTimer.StopLocked() - delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. @@ -899,23 +952,15 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() - // Check if we already have an auto-generated address for prefix. - for addr, addrState := range ndp.autoGenAddresses { - refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()} - if refAddrWithPrefix.Subnet() != prefix { - continue - } - - // At this point, we know we are refreshing a SLAAC generated IPv6 address - // with the prefix prefix. Do the work as outlined by RFC 4862 section - // 5.5.3.e. - ndp.refreshAutoGenAddressLifetimes(addr, pl, vl) + // Check if we already maintain SLAAC state for prefix. + if _, ok := ndp.slaacPrefixes[prefix]; ok { + // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. + ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl) return } - // We do not already have an address with the prefix prefix. Do the - // work as outlined by RFC 4862 section 5.5.3.d if n is configured - // to auto-generate global addresses by SLAAC. + // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC. if !ndp.configs.AutoGenGlobalAddresses { return } @@ -927,6 +972,8 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform // for prefix. // // pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // If we do not already have an address for this prefix and the valid // lifetime is 0, no need to do anything further, as per RFC 4862 @@ -942,10 +989,83 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { return } + state := slaacPrefixState{ + deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) + } + + ndp.deprecateSLAACAddress(state.ref) + }), + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) + } + + ndp.invalidateSLAACPrefix(prefix, state) + }), + maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, + } + + now := time.Now() + + // The time an address is preferred until is needed to properly generate the + // address. + if pl < header.NDPInfiniteLifetime { + state.preferredUntil = now.Add(pl) + } + + if !ndp.generateSLAACAddr(prefix, &state) { + // We were unable to generate an address for the prefix, we do not nothing + // further as there is no reason to maintain state or timers for a prefix we + // do not have an address for. + return + } + + // Setup the initial timers to deprecate and invalidate prefix. + + if pl < header.NDPInfiniteLifetime && pl != 0 { + state.deprecationTimer.Reset(pl) + } + + if vl < header.NDPInfiniteLifetime { + state.invalidationTimer.Reset(vl) + state.validUntil = now.Add(vl) + } + + ndp.slaacPrefixes[prefix] = state +} + +// generateSLAACAddr generates a SLAAC address for prefix. +// +// Returns true if an address was successfully generated. +// +// Panics if the prefix is not a SLAAC prefix or it already has an address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { + if r := state.ref; r != nil { + panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if state.generationAttempts == state.maxGenerationAttempts { + return false + } + addrBytes := []byte(prefix.ID()) if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey) - } else { + addrBytes = header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + state.generationAttempts, + oIID.SecretKey, + ) + } else if state.generationAttempts == 0 { // Only attempt to generate an interface-specific IID if we have a valid // link address. // @@ -953,137 +1073,140 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // LinkEndpoint.LinkAddress) before reaching this point. linkAddr := ndp.nic.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { - return + return false } // Generate an address within prefix from the modified EUI-64 of ndp's NIC's // Ethernet MAC address. header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } else { + // We have no way to regenerate an address when addresses are not generated + // with opaque IIDs. + return false } - addr := tcpip.Address(addrBytes) - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: validPrefixLenForAutoGen, + + generatedAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: validPrefixLenForAutoGen, + }, } // If the nic already has this address, do nothing further. - if ndp.nic.hasPermanentAddrLocked(addr) { - return + if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) { + return false } // Inform the integrator that we have a new SLAAC address. ndpDisp := ndp.nic.stack.ndpDisp if ndpDisp == nil { - return + return false } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) { // Informed by the integrator not to add the address. - return + return false } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addrWithPrefix, - } - // If the preferred lifetime is zero, then the address should be considered - // deprecated. - deprecated := pl == 0 - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + deprecated := time.Since(state.preferredUntil) >= 0 + ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) if err != nil { - log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) + panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err)) } - state := autoGenAddressState{ - ref: ref, - deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - addrState, ok := ndp.autoGenAddresses[addr] - if !ok { - log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) - } - addrState.ref.deprecated = true - ndp.notifyAutoGenAddressDeprecated(addr) - }), - invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - ndp.invalidateAutoGenAddress(addr) - }), - } - - // Setup the initial timers to deprecate and invalidate this newly generated - // address. + state.generationAttempts++ + state.ref = ref + return true +} - if !deprecated && pl < header.NDPInfiniteLifetime { - state.deprecationTimer.Reset(pl) +// regenerateSLAACAddr regenerates an address for a SLAAC prefix. +// +// If generating a new address for the prefix fails, the prefix will be +// invalidated. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix)) } - if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) - state.validUntil = time.Now().Add(vl) + if ndp.generateSLAACAddr(prefix, &state) { + ndp.slaacPrefixes[prefix] = state + return } - ndp.autoGenAddresses[addr] = state + // We were unable to generate a permanent address for the SLAAC prefix so + // invalidate the prefix as there is no reason to maintain state for a + // SLAAC prefix we do not have an address for. + ndp.invalidateSLAACPrefix(prefix, state) } -// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated -// address addr. +// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. // // pl is the new preferred lifetime. vl is the new valid lifetime. -func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) { - addrState, ok := ndp.autoGenAddresses[addr] +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) { + prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { - log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr) + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix)) } - defer func() { ndp.autoGenAddresses[addr] = addrState }() + defer func() { ndp.slaacPrefixes[prefix] = prefixState }() - // If the preferred lifetime is zero, then the address should be considered - // deprecated. + // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 - wasDeprecated := addrState.ref.deprecated - addrState.ref.deprecated = deprecated - - // Only send the deprecation event if the deprecated status for addr just - // changed from non-deprecated to deprecated. - if !wasDeprecated && deprecated { - ndp.notifyAutoGenAddressDeprecated(addr) + if deprecated { + ndp.deprecateSLAACAddress(prefixState.ref) + } else { + prefixState.ref.deprecated = false } - // If addr was preferred for some finite lifetime before, stop the deprecation - // timer so it can be reset. - addrState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, stop the + // deprecation timer so it can be reset. + prefixState.deprecationTimer.StopLocked() + + now := time.Now() - // Reset the deprecation timer if addr has a finite preferred lifetime. - if !deprecated && pl < header.NDPInfiniteLifetime { - addrState.deprecationTimer.Reset(pl) + // Reset the deprecation timer if prefix has a finite preferred lifetime. + if pl < header.NDPInfiniteLifetime { + if !deprecated { + prefixState.deprecationTimer.Reset(pl) + } + prefixState.preferredUntil = now.Add(pl) + } else { + prefixState.preferredUntil = time.Time{} } - // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address - // + // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: // // 1) If the received Valid Lifetime is greater than 2 hours or greater than - // RemainingLifetime, set the valid lifetime of the address to the + // RemainingLifetime, set the valid lifetime of the prefix to the // advertised Valid Lifetime. // // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the // advertised Valid Lifetime. // - // 3) Otherwise, reset the valid lifetime of the address to 2 hours. + // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. // Handle the infinite valid lifetime separately as we do not keep a timer in // this case. if vl >= header.NDPInfiniteLifetime { - addrState.invalidationTimer.StopLocked() - addrState.validUntil = time.Time{} + prefixState.invalidationTimer.StopLocked() + prefixState.validUntil = time.Time{} return } var effectiveVl time.Duration var rl time.Duration - // If the address was originally set to be valid forever, assume the remaining + // If the prefix was originally set to be valid forever, assume the remaining // time to be the maximum possible value. - if addrState.validUntil == (time.Time{}) { + if prefixState.validUntil == (time.Time{}) { rl = header.NDPInfiniteLifetime } else { - rl = time.Until(addrState.validUntil) + rl = time.Until(prefixState.validUntil) } if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { @@ -1094,58 +1217,78 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t effectiveVl = MinPrefixInformationValidLifetimeForUpdate } - addrState.invalidationTimer.StopLocked() - addrState.invalidationTimer.Reset(effectiveVl) - addrState.validUntil = time.Now().Add(effectiveVl) + prefixState.invalidationTimer.StopLocked() + prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.validUntil = now.Add(effectiveVl) } -// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr -// has been deprecated. -func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) { +// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP +// dispatcher that ref has been deprecated. +// +// deprecateSLAACAddress does nothing if ref is already deprecated. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { + if ref.deprecated { + return + } + + ref.deprecated = true if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: validPrefixLenForAutoGen, - }) + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) } } -// invalidateAutoGenAddress invalidates an auto-generated address. +// invalidateSLAACPrefix invalidates a SLAAC prefix. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { - if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) { - return +func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { + if r := state.ref; r != nil { + // Since we are already invalidating the prefix, do not invalidate the + // prefix when removing the address. + if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err)) + } } - ndp.nic.removePermanentAddressLocked(addr) + ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated -// address's resources from ndp. If the stack has an NDP dispatcher, it will -// be notified that addr has been invalidated. -// -// Returns true if ndp had resources for addr to cleanup. +// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's +// resources. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { - state, ok := ndp.autoGenAddresses[addr] - if !ok { - return false +func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) } - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() - delete(ndp.autoGenAddresses, addr) + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress { + return + } - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: validPrefixLenForAutoGen, - }) + if !invalidatePrefix { + // If the prefix is not being invalidated, disassociate the address from the + // prefix and do nothing further. + state.ref = nil + ndp.slaacPrefixes[prefix] = state + return } - return true + ndp.cleanupSLAACPrefixResources(prefix, state) +} + +// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// +// Panics if the SLAAC prefix is not known. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { + state.deprecationTimer.StopLocked() + state.invalidationTimer.StopLocked() + delete(ndp.slaacPrefixes, prefix) } // cleanupState cleans up ndp's state. @@ -1163,21 +1306,21 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - linkLocalAddrs := 0 - for addr := range ndp.autoGenAddresses { + linkLocalPrefixes := 0 + for prefix, state := range ndp.slaacPrefixes { // RFC 4862 section 5 states that routers are also expected to generate a // link-local address so we do not invalidate them if we are cleaning up // host-only state. - if hostOnly && linkLocalSubnet.Contains(addr) { - linkLocalAddrs++ + if hostOnly && prefix == linkLocalSubnet { + linkLocalPrefixes++ continue } - ndp.invalidateAutoGenAddress(addr) + ndp.invalidateSLAACPrefix(prefix, state) } - if got := len(ndp.autoGenAddresses); got != linkLocalAddrs { - log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs) + if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { + panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) } for prefix := range ndp.onLinkPrefixes { @@ -1185,7 +1328,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { } if got := len(ndp.onLinkPrefixes); got != 0 { - log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got) + panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)) } for router := range ndp.defaultRouters { @@ -1193,7 +1336,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { } if got := len(ndp.defaultRouters); got != 0 { - log.Fatalf("ndp: still have discovered default routers after cleaning up; found = %d", got) + panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) } } @@ -1235,9 +1378,9 @@ func (ndp *ndpState) startSolicitingRouters() { // header.IPv6AllRoutersMulticastAddress is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { - log.Fatalf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) } else if c != nil { - log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1267,7 +1410,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, tcpip.PacketBuffer{Header: hdr}, + }, PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 98b1c807c..6dd460984 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -133,6 +133,12 @@ type ndpRDNSSEvent struct { rdnss ndpRDNSS } +type ndpDNSSLEvent struct { + nicID tcpip.NICID + domainNames []string + lifetime time.Duration +} + type ndpDHCPv6Event struct { nicID tcpip.NICID configuration stack.DHCPv6ConfigurationFromNDPRA @@ -150,6 +156,8 @@ type ndpDispatcher struct { rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent + dnsslC chan ndpDNSSLEvent + routeTable []tcpip.Route dhcpv6ConfigurationC chan ndpDHCPv6Event } @@ -257,6 +265,17 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } +// Implements stack.NDPDispatcher.OnDNSSearchListOption. +func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { + if n.dnsslC != nil { + n.dnsslC <- ndpDNSSLEvent{ + nicID, + domainNames, + lifetime, + } + } +} + // Implements stack.NDPDispatcher.OnDHCPv6Configuration. func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { if c := n.dhcpv6ConfigurationC; c != nil { @@ -406,8 +425,7 @@ func TestDADResolve(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } - // Address should not be considered bound to the NIC yet - // (DAD ongoing). + // Address should not be considered bound to the NIC yet (DAD ongoing). addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) @@ -416,10 +434,9 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - // Wait for the remaining time - some delta (500ms), to - // make sure the address is still not resolved. - const delta = 500 * time.Millisecond - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) + // Make sure the address does not resolve before the resolution time has + // passed. + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout) addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) @@ -430,13 +447,7 @@ func TestDADResolve(t *testing.T) { // Wait for DAD to resolve. select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. + case <-time.After(2 * defaultAsyncEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { @@ -476,7 +487,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.IPv6(t, p.Pkt.Header.View(), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), @@ -602,7 +613,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -631,6 +642,12 @@ func TestDADFail(t *testing.T) { if want := (tcpip.AddressWithPrefix{}); addr != want { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } + + // Attempting to add the address again should not fail if the address's + // state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } }) } } @@ -639,8 +656,9 @@ func TestDADStop(t *testing.T) { const nicID = 1 tests := []struct { - name string - stopFn func(t *testing.T, s *stack.Stack) + name string + stopFn func(t *testing.T, s *stack.Stack) + skipFinalAddrCheck bool }{ // Tests to make sure that DAD stops when an address is removed. { @@ -661,6 +679,19 @@ func TestDADStop(t *testing.T) { } }, }, + + // Tests to make sure that DAD stops when the NIC is removed. + { + name: "Remove NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("RemoveNIC(%d): %s", nicID, err) + } + }, + // The NIC is removed so we can't check its addresses after calling + // stopFn. + skipFinalAddrCheck: true, + }, } for _, test := range tests { @@ -710,12 +741,15 @@ func TestDADStop(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + + if !test.skipFinalAddrCheck { + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } } // Should not have sent more than 1 NS message. @@ -901,7 +935,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -936,14 +970,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -952,7 +986,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -960,7 +994,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -969,7 +1003,7 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. @@ -1017,8 +1051,6 @@ func TestNoRouterDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), } @@ -1057,8 +1089,6 @@ func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) str // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), } @@ -1099,8 +1129,6 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, @@ -1202,8 +1230,6 @@ func TestRouterDiscovery(t *testing.T) { // TestRouterDiscoveryMaxRouters tests that only // stack.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, @@ -1270,8 +1296,6 @@ func TestNoPrefixDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 1), } @@ -1311,8 +1335,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st // TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered on-link prefix when the dispatcher asks it not to. func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - t.Parallel() - prefix, subnet, _ := prefixSubnetAddr(0, "") ndpDisp := ndpDispatcher{ @@ -1356,8 +1378,6 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } func TestPrefixDiscovery(t *testing.T) { - t.Parallel() - prefix1, subnet1, _ := prefixSubnetAddr(0, "") prefix2, subnet2, _ := prefixSubnetAddr(1, "") prefix3, subnet3, _ := prefixSubnetAddr(2, "") @@ -1546,8 +1566,6 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // TestPrefixDiscoveryMaxRouters tests that only // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, @@ -1642,8 +1660,6 @@ func TestNoAutoGenAddr(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } @@ -1968,7 +1984,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // addr2 is deprecated but if explicitly requested, it should be used. fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) } // Another PI w/ 0 preferred lifetime should not result in a deprecation @@ -1981,7 +1997,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { } expectPrimaryAddr(addr1) if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) } // Refresh lifetimes of addr generated from prefix2. @@ -2093,7 +2109,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make @@ -2106,7 +2122,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } expectPrimaryAddr(addr2) if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Refresh lifetimes for addr of prefix1. @@ -2130,7 +2146,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since it is not deprecated. expectPrimaryAddr(addr2) if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Wait for addr of prefix1 to be invalidated. @@ -2393,8 +2409,6 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { }, } - const delta = 500 * time.Millisecond - // This Run will not return until the parallel tests finish. // // We need this because we need to do some teardown work after the @@ -2447,24 +2461,21 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { // to test.evl. // - // Make sure we do not get any invalidation - // events until atleast 500ms (delta) before - // test.evl. + // The address should not be invalidated until the effective valid + // lifetime has passed. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - delta): + case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout): } - // Wait for another second (2x delta), but now - // we expect the invalidation event. + // Wait for the invalidation event. select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - - case <-time.After(2 * delta): + case <-time.After(2 * defaultAsyncEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -2476,8 +2487,6 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { // by the user, its resources will be cleaned up and an invalidation event will // be sent to the integrator. func TestAutoGenAddrRemoval(t *testing.T) { - t.Parallel() - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) ndpDisp := ndpDispatcher{ @@ -2534,8 +2543,6 @@ func TestAutoGenAddrRemoval(t *testing.T) { // TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously // assigned to the NIC but is in the permanentExpired state. func TestAutoGenAddrAfterRemoval(t *testing.T) { - t.Parallel() - const nicID = 1 prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) @@ -2582,7 +2589,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { AddressWithPrefix: addr2, } if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) } // addr2 should be more preferred now since it is at the front of the primary // list. @@ -2647,8 +2654,6 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that // is already assigned to the NIC, the static address remains. func TestAutoGenAddrStaticConflict(t *testing.T) { - t.Parallel() - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) ndpDisp := ndpDispatcher{ @@ -2704,8 +2709,6 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use // opaque interface identifiers when configured to do so. func TestAutoGenAddrWithOpaqueIID(t *testing.T) { - t.Parallel() - const nicID = 1 const nicName = "nic1" var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte @@ -2805,12 +2808,465 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } } +// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an +// auto-generated IPv6 address in response to a DAD conflict. +func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { + const nicID = 1 + const nicName = "nic" + const dadTransmits = 1 + const retransmitTimer = time.Second + const maxMaxRetries = 3 + const lifetimeSeconds = 10 + + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { + for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + subnet tcpip.Subnet + triggerSLAACFn func(e *channel.Endpoint) + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + subnet: subnet, + triggerSLAACFn: func(e *channel.Endpoint) { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + AutoGenAddressConflictRetries: maxRetries, + }, + autoGenLinkLocal: true, + subnet: header.IPv6LinkLocalPrefix.Subnet(), + triggerSLAACFn: func(e *channel.Endpoint) {}, + }, + } + + for _, addrType := range addrTypes { + maxRetries := maxRetries + numFailures := numFailures + addrType := addrType + + t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + addrType.triggerSLAACFn(e) + + // Simulate DAD conflicts so the address is regenerated. + for i := uint8(0); i < numFailures; i++ { + addrBytes := []byte(addrType.subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Should not have any addresses assigned to the NIC. + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); mainAddr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + } + + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Attempting to add the address manually should not fail if the + // address's state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + } + if err := s.RemoveAddress(nicID, addr.Address); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + } + + // Should not have any addresses assigned to the NIC. + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); mainAddr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + } + + // If we had less failures than generation attempts, we should have an + // address after DAD resolves. + if maxRetries+1 > numFailures { + addrBytes := []byte(addrType.subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if mainAddr != addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr) + } + } + + // Should not attempt address regeneration again. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } + } + } +} + +// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is +// not made for SLAAC addresses generated with an IID based on the NIC's link +// address. +func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { + const nicID = 1 + const dadTransmits = 1 + const retransmitTimer = time.Second + const maxRetries = 3 + const lifetimeSeconds = 10 + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + subnet tcpip.Subnet + triggerSLAACFn func(e *channel.Endpoint) + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + subnet: subnet, + triggerSLAACFn: func(e *channel.Endpoint) { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + AutoGenAddressConflictRetries: maxRetries, + }, + autoGenLinkLocal: true, + subnet: header.IPv6LinkLocalPrefix.Subnet(), + triggerSLAACFn: func(e *channel.Endpoint) {}, + }, + } + + for _, addrType := range addrTypes { + addrType := addrType + + t.Run(addrType.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + addrType.triggerSLAACFn(e) + + addrBytes := []byte(addrType.subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:]) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Should not attempt address regeneration. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } +} + +// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address +// generation in response to DAD conflicts does not refresh the lifetimes. +func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { + const nicID = 1 + const nicName = "nic" + const dadTransmits = 1 + const retransmitTimer = 2 * time.Second + const failureTimer = time.Second + const maxRetries = 1 + const lifetimeSeconds = 5 + + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + addrBytes := []byte(subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Simulate a DAD conflict after some time has passed. + time.Sleep(failureTimer) + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Let the next address resolve. + addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) + expectAutoGenAddrEvent(addr, newAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + // Address should be deprecated/invalidated after the lifetime expires. + // + // Note, the remaining lifetime is calculated from when the PI was first + // processed. Since we wait for some time before simulating a DAD conflict + // and more time for the new address to resolve, the new address is only + // expected to be valid for the remaining time. The DAD conflict should + // not have reset the lifetimes. + // + // We expect either just the invalidation event or the deprecation event + // followed by the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if e.eventType == deprecatedAddr { + if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") + } + } else { + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + } + case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for auto gen addr event") + } +} + // TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event // to the integrator when an RA is received with the NDP Recursive DNS Server // option with at least one valid address. func TestNDPRecursiveDNSServerDispatch(t *testing.T) { - t.Parallel() - tests := []struct { name string opt header.NDPRecursiveDNSServer @@ -2902,11 +3358,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ // We do not expect more than a single RDNSS // event at any time for this test. @@ -2953,11 +3405,115 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } } +// TestNDPDNSSearchListDispatch tests that the integrator is informed when an +// NDP DNS Search List option is received with at least one domain name in the +// search list. +func TestNDPDNSSearchListDispatch(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dnsslC: make(chan ndpDNSSLEvent, 3), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + optSer := header.NDPOptionsSerializer{ + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 0, 0, + 2, 'h', 'i', + 0, + }), + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 'i', + 0, + 2, 'a', 'm', + 2, 'm', 'e', + 0, + }), + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 1, 0, + 3, 'x', 'y', 'z', + 0, + 5, 'h', 'e', 'l', 'l', 'o', + 5, 'w', 'o', 'r', 'l', 'd', + 0, + 4, 't', 'h', 'i', 's', + 2, 'i', 's', + 1, 'a', + 4, 't', 'e', 's', 't', + 0, + }), + } + expected := []struct { + domainNames []string + lifetime time.Duration + }{ + { + domainNames: []string{ + "hi", + }, + lifetime: 0, + }, + { + domainNames: []string{ + "i", + "am.me", + }, + lifetime: time.Second, + }, + { + domainNames: []string{ + "xyz", + "hello.world", + "this.is.a.test", + }, + lifetime: 256 * time.Second, + }, + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + + for i, expected := range expected { + select { + case dnssl := <-ndpDisp.dnsslC: + if dnssl.nicID != nicID { + t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID) + } + if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" { + t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff) + } + if dnssl.lifetime != expected.lifetime { + t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime) + } + default: + t.Fatal("expected a DNSSL event") + } + } + + // Should have no more DNSSL options. + select { + case <-ndpDisp.dnsslC: + t.Fatal("unexpectedly got a DNSSL event") + default: + } +} + // TestCleanupNDPState tests that all discovered routers and prefixes, and // auto-generated addresses are invalidated when a NIC becomes a router. func TestCleanupNDPState(t *testing.T) { - t.Parallel() - const ( lifetimeSeconds = 5 maxRouterAndPrefixEvents = 4 @@ -2983,11 +3539,12 @@ func TestCleanupNDPState(t *testing.T) { cleanupFn func(t *testing.T, s *stack.Stack) keepAutoGenLinkLocal bool maxAutoGenAddrEvents int + skipFinalAddrCheck bool }{ // A NIC should still keep its auto-generated link-local address when // becoming a router. { - name: "Forwarding Enable", + name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() s.SetForwarding(true) @@ -2998,7 +3555,7 @@ func TestCleanupNDPState(t *testing.T) { // A NIC should cleanup all NDP state when it is disabled. { - name: "NIC Disable", + name: "Disable NIC", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() @@ -3012,6 +3569,26 @@ func TestCleanupNDPState(t *testing.T) { keepAutoGenLinkLocal: false, maxAutoGenAddrEvents: 6, }, + + // A NIC should cleanup all NDP state when it is removed. + { + name: "Remove NIC", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + + if err := s.RemoveNIC(nicID1); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err) + } + if err := s.RemoveNIC(nicID2); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + // The NICs are removed so we can't check their addresses after calling + // stopFn. + skipFinalAddrCheck: true, + }, } for _, test := range tests { @@ -3230,35 +3807,37 @@ func TestCleanupNDPState(t *testing.T) { t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) } - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + if !test.skipFinalAddrCheck { + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } } - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + if containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + } + if containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } // Should not get any more events (invalidation timers should have been @@ -3377,8 +3956,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // TestRouterSolicitation tests the initial Router Solicitations that are sent // when a NIC newly becomes enabled. func TestRouterSolicitation(t *testing.T) { - t.Parallel() - const nicID = 1 tests := []struct { @@ -3395,13 +3972,22 @@ func TestRouterSolicitation(t *testing.T) { effectiveMaxRtrSolicitDelay time.Duration }{ { - name: "Single RS with delay", + name: "Single RS with 2s delay and interval", expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 1, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: time.Second, - effectiveMaxRtrSolicitDelay: time.Second, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, + maxRtrSolicitDelay: 2 * time.Second, + effectiveMaxRtrSolicitDelay: 2 * time.Second, + }, + { + name: "Single RS with 4s delay and interval", + expectedSrcAddr: header.IPv6Any, + maxRtrSolicit: 1, + rtrSolicitInt: 4 * time.Second, + effectiveRtrSolicitInt: 4 * time.Second, + maxRtrSolicitDelay: 4 * time.Second, + effectiveMaxRtrSolicitDelay: 4 * time.Second, }, { name: "Two RS with delay", @@ -3409,8 +3995,8 @@ func TestRouterSolicitation(t *testing.T) { nicAddr: llAddr1, expectedSrcAddr: llAddr1, maxRtrSolicit: 2, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, maxRtrSolicitDelay: 500 * time.Millisecond, effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, }, @@ -3424,8 +4010,8 @@ func TestRouterSolicitation(t *testing.T) { header.NDPSourceLinkLayerAddressOption(linkAddr1), }, maxRtrSolicit: 1, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, maxRtrSolicitDelay: 0, effectiveMaxRtrSolicitDelay: 0, }, @@ -3475,6 +4061,7 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() + e := channelLinkWithHeaderLength{ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), headerLength: test.linkHeaderLen, @@ -3482,7 +4069,8 @@ func TestRouterSolicitation(t *testing.T) { e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() p, ok := e.ReadContext(ctx) if !ok { t.Fatal("timed out waiting for packet") @@ -3512,7 +4100,8 @@ func TestRouterSolicitation(t *testing.T) { } waitForNothing := func(timeout time.Duration) { t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") } @@ -3543,15 +4132,19 @@ func TestRouterSolicitation(t *testing.T) { } for ; remaining > 0; remaining-- { - waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) - waitForPkt(defaultAsyncEventTimeout) + if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout) + waitForPkt(2 * defaultAsyncEventTimeout) + } else { + waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout) + } } // Make sure no more RS. if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout) + waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout) + waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) } // Make sure the counter got properly @@ -3565,27 +4158,27 @@ func TestRouterSolicitation(t *testing.T) { } func TestStopStartSolicitingRouters(t *testing.T) { - t.Parallel() - const nicID = 1 + const delay = 0 const interval = 500 * time.Millisecond - const delay = time.Second const maxRtrSolicitations = 3 tests := []struct { name string startFn func(t *testing.T, s *stack.Stack) - stopFn func(t *testing.T, s *stack.Stack) + // first is used to tell stopFn that it is being called for the first time + // after router solicitations were last enabled. + stopFn func(t *testing.T, s *stack.Stack, first bool) }{ // Tests that when forwarding is enabled or disabled, router solicitations // are stopped or started, respectively. { - name: "Forwarding enabled and disabled", + name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() s.SetForwarding(false) }, - stopFn: func(t *testing.T, s *stack.Stack) { + stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() s.SetForwarding(true) }, @@ -3594,7 +4187,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Tests that when a NIC is enabled or disabled, router solicitations // are started or stopped, respectively. { - name: "NIC disabled and enabled", + name: "Enable and disable NIC", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() @@ -3602,7 +4195,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } }, - stopFn: func(t *testing.T, s *stack.Stack) { + stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() if err := s.DisableNIC(nicID); err != nil { @@ -3610,6 +4203,25 @@ func TestStopStartSolicitingRouters(t *testing.T) { } }, }, + + // Tests that when a NIC is removed, router solicitations are stopped. We + // cannot start router solications on a removed NIC. + { + name: "Remove NIC", + stopFn: func(t *testing.T, s *stack.Stack, first bool) { + t.Helper() + + // Only try to remove the NIC the first time stopFn is called since it's + // impossible to remove an already removed NIC. + if !first { + return + } + + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) + } + }, + }, } for _, test := range tests { @@ -3623,7 +4235,6 @@ func TestStopStartSolicitingRouters(t *testing.T) { p, ok := e.ReadContext(ctx) if !ok { t.Fatal("timed out waiting for packet") - return } if p.Proto != header.IPv6ProtocolNumber { @@ -3648,12 +4259,12 @@ func TestStopStartSolicitingRouters(t *testing.T) { } // Stop soliciting routers. - test.stopFn(t, s) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) + test.stopFn(t, s, true /* first */) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { - // A single RS may have been sent before forwarding was enabled. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout) + // A single RS may have been sent before solicitations were stopped. + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) defer cancel() if _, ok = e.ReadContext(ctx); ok { t.Fatal("should not have sent more than one RS message") @@ -3662,19 +4273,24 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. - test.stopFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + test.stopFn(t, s, false /* first */) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") } + // If test.startFn is nil, there is no way to restart router solications. + if test.startFn == nil { + return + } + // Start soliciting routers. test.startFn(t, s) waitForPkt(delay + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") @@ -3683,7 +4299,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3e6196aee..0c2b1f36a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -15,7 +15,7 @@ package stack import ( - "log" + "fmt" "reflect" "sort" "strings" @@ -54,7 +54,7 @@ type NIC struct { primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint endpoints map[NetworkEndpointID]*referencedNetworkEndpoint addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mcastJoins map[NetworkEndpointID]uint32 // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint @@ -121,15 +121,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC } nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) - nic.mu.mcastJoins = make(map[NetworkEndpointID]int32) + nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) nic.mu.ndp = ndpState{ - nic: nic, - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), + nic: nic, + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } // Register supported packet endpoint protocols. @@ -165,8 +165,17 @@ func (n *NIC) disable() *tcpip.Error { } n.mu.Lock() - defer n.mu.Unlock() + err := n.disableLocked() + n.mu.Unlock() + return err +} +// disableLocked disables n. +// +// It undoes the work done by enable. +// +// n MUST be locked. +func (n *NIC) disableLocked() *tcpip.Error { if !n.mu.enabled { return nil } @@ -189,7 +198,7 @@ func (n *NIC) disable() *tcpip.Error { } // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } } @@ -305,24 +314,33 @@ func (n *NIC) remove() *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - // Detach from link endpoint, so no packet comes in. - n.linkEP.Attach(nil) + n.disableLocked() + + // TODO(b/151378115): come up with a better way to pick an error than the + // first one. + var err *tcpip.Error + + // Forcefully leave multicast groups. + for nid := range n.mu.mcastJoins { + if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { + err = tempErr + } + } // Remove permanent and permanentTentative addresses, so no packet goes out. - var errs []*tcpip.Error for nid, ref := range n.mu.endpoints { switch ref.getKind() { case permanentTentative, permanent: - if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil { - errs = append(errs, err) + if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { + err = tempErr } } } - if len(errs) > 0 { - return errs[0] - } - return nil + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) + + return err } // becomeIPv6Router transitions n into an IPv6 router. @@ -461,7 +479,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn // Should never happen as we got r from the primary IPv6 endpoint list and // ScopeForIPv6Address only returns an error if addr is not an IPv6 // address. - log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err) + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) } cs = append(cs, ipv6AddrCandidate{ @@ -473,7 +491,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn remoteScope, err := header.ScopeForIPv6Address(remoteAddr) if err != nil { // primaryIPv6Endpoint should never be called with an invalid IPv6 address. - log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err) + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) } // Sort the addresses as per RFC 6724 section 5 rules 1-3. @@ -969,6 +987,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { for i, ref := range refs { if ref == r { n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + refs[len(refs)-1] = nil break } } @@ -993,35 +1012,42 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress } - isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) + switch r.protocol { + case header.IPv6ProtocolNumber: + return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAAPrefixInvalidation */) + default: + r.expireLocked() + return nil + } +} + +func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error { + addr := r.addrWithPrefix() + + isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) if isIPv6Unicast { - // If we are removing a tentative IPv6 unicast address, stop - // DAD. - if kind == permanentTentative { - n.mu.ndp.stopDuplicateAddressDetection(addr) - } + n.mu.ndp.stopDuplicateAddressDetection(addr.Address) // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. if r.configType == slaac { - n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation) } } - r.setKind(permanentExpired) - if !r.decRefLocked() { - // The endpoint still has references to it. - return nil - } + r.expireLocked() // At this point the endpoint is deleted. // If we are removing an IPv6 unicast address, leave the solicited-node // multicast address. + // + // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node + // multicast group may be left by user action. if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr) - if err := n.leaveGroupLocked(snmc); err != nil { + snmc := header.SolicitedNodeAddr(addr.Address) + if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } } @@ -1081,26 +1107,31 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.leaveGroupLocked(addr) + return n.leaveGroupLocked(addr, false /* force */) } // leaveGroupLocked decrements the count for the given multicast address, and // when it reaches zero removes the endpoint for this address. n MUST be locked // before leaveGroupLocked is called. -func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +// +// If force is true, then the count for the multicast addres is ignored and the +// endpoint will be removed immediately. +func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { id := NetworkEndpointID{addr} - joins := n.mu.mcastJoins[id] - switch joins { - case 0: + joins, ok := n.mu.mcastJoins[id] + if !ok { // There are no joins with this address on this NIC. return tcpip.ErrBadLocalAddress - case 1: - // This is the last one, clean up. - if err := n.removePermanentAddressLocked(addr); err != nil { - return err - } } - n.mu.mcastJoins[id] = joins - 1 + + joins-- + if force || joins == 0 { + // There are no outstanding joins or we are forced to leave, clean up. + delete(n.mu.mcastJoins, id) + return n.removePermanentAddressLocked(addr) + } + + n.mu.mcastJoins[id] = joins return nil } @@ -1113,9 +1144,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, pkt) ref.decRef() } @@ -1126,7 +1158,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1171,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1186,6 +1218,16 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() return } + + // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. + if protocol == header.IPv4ProtocolNumber { + ipt := n.stack.IPTables() + if ok := ipt.Check(Prerouting, pkt); !ok { + // iptables is telling us to drop the packet. + return + } + } + if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return @@ -1201,10 +1243,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() return } - defer r.Release() - - r.LocalLinkAddress = n.linkEP.LinkAddress() - r.RemoteLinkAddress = remote // Found a NIC. n := r.ref.nic @@ -1213,24 +1251,33 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() n.mu.RUnlock() if ok { + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, pkt) ref.decRef() - } else { - // n doesn't have a destination endpoint. - // Send the packet out of n. - pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) - pkt.Data.RemoveFirst() - - // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + r.Release() + return + } + + // n doesn't have a destination endpoint. + // Send the packet out of n. + // TODO(b/128629022): move this logic to route.WritePacket. + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) + // forwarder will release route. + return } + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() + r.Release() + return } + + // The link-address resolution finished immediately. + n.forwardPacket(&r, protocol, pkt) + r.Release() return } @@ -1240,9 +1287,24 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } } +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) + } + + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) +} + // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1256,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1288,7 +1351,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -1299,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } @@ -1351,10 +1415,12 @@ func (n *NIC) isAddrTentative(addr tcpip.Address) bool { return ref.getKind() == permanentTentative } -// dupTentativeAddrDetected attempts to inform n that a tentative addr -// is a duplicate on a link. +// dupTentativeAddrDetected attempts to inform n that a tentative addr is a +// duplicate on a link. // -// dupTentativeAddrDetected will delete the tentative address if it exists. +// dupTentativeAddrDetected will remove the tentative address if it exists. If +// the address was generated via SLAAC, an attempt will be made to generate a +// new address. func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -1368,7 +1434,17 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - return n.removePermanentAddressLocked(addr) + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a + // new address will be generated for it. + if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil { + return err + } + + if ref.configType == slaac { + n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet()) + } + + return nil } // setNDPConfigs sets the NDP configurations for n. @@ -1496,6 +1572,13 @@ type referencedNetworkEndpoint struct { deprecated bool } +func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { + return tcpip.AddressWithPrefix{ + Address: r.ep.ID().LocalAddress, + PrefixLen: r.ep.PrefixLen(), + } +} + func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) } @@ -1523,6 +1606,13 @@ func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) } +// expireLocked decrements the reference count and marks the permanent endpoint +// as expired. +func (r *referencedNetworkEndpoint) expireLocked() { + r.setKind(permanentExpired) + r.decRefLocked() +} + // decRef decrements the ref count and cleans up the endpoint once it reaches // zero. func (r *referencedNetworkEndpoint) decRef() { @@ -1532,14 +1622,11 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. Returns true if the endpoint was removed. -func (r *referencedNetworkEndpoint) decRefLocked() bool { +// locked. +func (r *referencedNetworkEndpoint) decRefLocked() { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) - return true } - - return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index edaee3b86..d672fc157 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -17,7 +17,6 @@ package stack import ( "testing" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -45,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go new file mode 100644 index 000000000..7d36f8e84 --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer.go @@ -0,0 +1,78 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +type PacketBuffer struct { + // PacketBufferEntry is used to build an intrusive list of + // PacketBuffers. + PacketBufferEntry + + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. + // + // TODO(gvisor.dev/issue/170): Forwarded packets don't currently + // populate Header, but should. This will be doable once early parsing + // (https://github.com/google/gvisor/pull/1995) is supported. + Header buffer.Prependable + + // These fields are used by both inbound and outbound packets. They + // typically overlap with the Data and Header fields. + // + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other slices (either Data or Header). + // SR dosen't support this, so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View + + // Hash is the transport layer hash of this packet. A value of zero + // indicates no valid hash has been set. + Hash uint32 + + // Owner is implemented by task to get the uid and gid. + // Only set for locally generated packets. + Owner tcpip.PacketOwner +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +// +// Clone also does not deep copy any of its other fields. +func (pk PacketBuffer) Clone() PacketBuffer { + pk.Data = pk.Data.Clone(nil) + return pk +} diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go new file mode 100644 index 000000000..421fb5c15 --- /dev/null +++ b/pkg/tcpip/stack/rand.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + mathrand "math/rand" + + "gvisor.dev/gvisor/pkg/sync" +) + +// lockedRandomSource provides a threadsafe rand.Source. +type lockedRandomSource struct { + mu sync.Mutex + src mathrand.Source +} + +func (r *lockedRandomSource) Int63() (n int64) { + r.mu.Lock() + n = r.src.Int63() + r.mu.Unlock() + return n +} + +func (r *lockedRandomSource) Seed(seed int64) { + r.mu.Lock() + r.src.Seed(seed) + r.mu.Unlock() +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index f9fd8f18f..23ca9ee03 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,12 +67,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +100,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, pkt PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +118,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +150,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -180,7 +180,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +189,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -242,15 +242,15 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -265,7 +265,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, pkt PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -322,7 +322,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -354,7 +354,7 @@ const ( // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets // out through the implementer's data link endpoint. When a link header exists, -// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the +// it sets each PacketBuffer's LinkHeader field before passing it up the // stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is @@ -385,7 +385,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the // given route. pkts must not be zero length. @@ -393,7 +393,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. @@ -401,6 +401,9 @@ type LinkEndpoint interface { // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. + // + // Attach will be called with a nil dispatcher if the receiver's associated + // NIC is being removed. Attach(dispatcher NetworkDispatcher) // IsAttached returns whether a NetworkDispatcher is attached to the @@ -423,7 +426,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index f565aafb2..a0e5e0300 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -153,7 +153,7 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } @@ -168,29 +168,32 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack return err } -// WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +// WritePackets writes a list of n packets through the given route and returns +// the number of packets written. +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - payloadSize := 0 - for i := 0; i < n; i++ { - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) - payloadSize += pkts[i].DataSize + + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Header.UsedLength() + writtenBytes += pb.Data.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 13354d884..41398a1b6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,7 +20,9 @@ package stack import ( + "bytes" "encoding/binary" + mathrand "math/rand" "sync/atomic" "time" @@ -31,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" @@ -51,7 +52,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -428,7 +429,7 @@ type Stack struct { // tables are the iptables packet filtering and manipulation rules. The are // protected by tablesMu.` - tables iptables.IPTables + tables IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -462,6 +463,14 @@ type Stack struct { // opaqueIIDOpts hold the options for generating opaque interface identifiers // (IIDs) as outlined by RFC 7217. opaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // forwarder holds the packets that wait for their link-address resolutions + // to complete, and forwards them when each resolution is done. + forwarder *forwardQueue + + // randomGenerator is an injectable pseudo random generator that can be + // used when a random number is required. + randomGenerator *mathrand.Rand } // UniqueID is an abstract generator of unique identifiers. @@ -522,9 +531,16 @@ type Options struct { // this is non-nil. RawFactory RawFactory - // OpaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. + // OpaqueIIDOpts hold the options for generating opaque interface + // identifiers (IIDs) as outlined by RFC 7217. OpaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // RandSource is an optional source to use to generate random + // numbers. If omitted it defaults to a Source seeded by the data + // returned by rand.Read(). + // + // RandSource must be thread-safe. + RandSource mathrand.Source } // TransportEndpointInfo holds useful information about a transport endpoint @@ -620,6 +636,13 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + randSrc := opts.RandSource + if randSrc == nil { + // Source provided by mathrand.NewSource is not thread-safe so + // we wrap it in a simple thread-safe version. + randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + } + // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() @@ -641,6 +664,8 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, + forwarder: newForwardQueue(), + randomGenerator: mathrand.New(randSrc), } // Add specified network protocols. @@ -733,7 +758,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1696,7 +1721,7 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() iptables.IPTables { +func (s *Stack) IPTables() IPTables { s.tablesMu.RLock() t := s.tables s.tablesMu.RUnlock() @@ -1704,7 +1729,7 @@ func (s *Stack) IPTables() iptables.IPTables { } // SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt iptables.IPTables) { +func (s *Stack) SetIPTables(ipt IPTables) { s.tablesMu.Lock() s.tables = ipt s.tablesMu.Unlock() @@ -1814,6 +1839,12 @@ func (s *Stack) Seed() uint32 { return s.seed } +// Rand returns a reference to a pseudo random generator that can be used +// to generate random numbers as required. +func (s *Stack) Rand() *mathrand.Rand { + return s.randomGenerator +} + func generateRandUint32() uint32 { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { @@ -1821,3 +1852,16 @@ func generateRandUint32() uint32 { } return binary.LittleEndian.Uint32(b) } + +func generateRandInt64() int64 { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + panic(err) + } + buf := bytes.NewReader(b) + var v int64 + if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { + panic(err) + } + return v +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e15db40fb..d45d2cc1f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -90,21 +90,23 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return @@ -126,7 +128,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -141,7 +143,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, tcpip.PacketBuffer{ + f.HandlePacket(r, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } @@ -153,11 +155,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -255,7 +257,7 @@ type linkEPWithMockedAttach struct { // Attach implements stack.LinkEndpoint.Attach. func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { l.LinkEndpoint.Attach(d) - l.attached = true + l.attached = d != nil } func (l *linkEPWithMockedAttach) isAttached() bool { @@ -287,7 +289,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -299,7 +301,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -311,7 +313,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -322,7 +324,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -334,7 +336,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -356,7 +358,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -414,7 +416,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -566,7 +568,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) } if !e.isAttached() { - t.Fatalf("link endpoint not attached to a network disatcher") + t.Fatal("link endpoint not attached to a network dispatcher") } }) } @@ -631,196 +633,240 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { checkNIC(false) } -func TestRoutesWithDisabledNIC(t *testing.T) { - const unspecifiedNIC = 0 - const nicID1 = 1 - const nicID2 = 2 - +func TestRemoveUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - ep1 := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) } +} - addr1 := tcpip.Address("\x01") - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } +func TestRemoveNIC(t *testing.T) { + const nicID = 1 - ep2 := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addr2 := tcpip.Address("\x02") - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") } - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) } + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) + } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } +} - // Test routes to odd address. - testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) - testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) - testRoute(t, s, nicID1, addr1, "\x05", addr1) +func TestRouteWithDownNIC(t *testing.T) { + tests := []struct { + name string + downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + }{ + { + name: "Disabled NIC", + downFn: (*stack.Stack).DisableNIC, + upFn: (*stack.Stack).EnableNIC, + }, + + // Once a NIC is removed, it cannot be brought up. + { + name: "Removed NIC", + downFn: (*stack.Stack).RemoveNIC, + }, + } - // Test routes to even address. - testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) - testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) - testRoute(t, s, nicID2, addr2, "\x06", addr2) - - // Disabling NIC1 should result in no routes to odd addresses. Routes to even - // addresses should continue to be available as NIC2 is still enabled. - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - nic1Dst := tcpip.Address("\x05") - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - nic2Dst := tcpip.Address("\x06") - testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) - testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) - testRoute(t, s, nicID2, addr2, nic2Dst, addr2) - - // Disabling NIC2 should result in no routes to even addresses. No route - // should be available to any address as routes to odd addresses were made - // unavailable by disabling NIC1 above. - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - - // Enabling NIC1 should make routes to odd addresses available again. Routes - // to even addresses should continue to be unavailable as NIC2 is still - // disabled. - if err := s.EnableNIC(nicID1); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) - } - testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) - testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) - testRoute(t, s, nicID1, addr1, nic1Dst, addr1) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) -} - -func TestRouteWritePacketWithDisabledNIC(t *testing.T) { const unspecifiedNIC = 0 const nicID1 = 1 const nicID2 = 2 + const addr1 = tcpip.Address("\x01") + const addr2 = tcpip.Address("\x02") + const nic1Dst = tcpip.Address("\x05") + const nic2Dst = tcpip.Address("\x06") + + setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } - addr1 := tcpip.Address("\x01") - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } - addr2 := tcpip.Address("\x02") - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + return s, ep1, ep2 } - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) + // Tests that routes through a down NIC are not used when looking up a route + // for a destination. + t.Run("Find", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, _, _ := setup(t) + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Bringing NIC1 down should result in no routes to odd addresses. Routes to + // even addresses should continue to be available as NIC2 is still up. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Bringing NIC2 down should result in no routes to even addresses. No + // route should be available to any address as routes to odd addresses + // were made unavailable by bringing NIC1 down above. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + if upFn := test.upFn; upFn != nil { + // Bringing NIC1 up should make routes to odd addresses available + // again. Routes to even addresses should continue to be unavailable + // as NIC2 is still down. + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + } + }) } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) - } + }) - nic1Dst := tcpip.Address("\x05") - r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) - } - defer r1.Release() + // Tests that writing a packet using a Route through a down NIC fails. + t.Run("WritePacket", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, ep1, ep2 := setup(t) - nic2Dst := tcpip.Address("\x06") - r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) - } - defer r2.Release() + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() - // If we failed to get routes r1 or r2, we cannot proceed with the test. - if t.Failed() { - t.FailNow() - } + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() - buf := buffer.View([]byte{1}) - testSend(t, r1, ep1, buf) - testSend(t, r2, ep2, buf) + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } - // Writes with Routes that use the disabled NIC1 should fail. - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testSend(t, r2, ep2, buf) + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) - // Writes with Routes that use the disabled NIC2 should fail. - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + // Writes with Routes that use NIC1 after being brought down should fail. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) - // Writes with Routes that use the re-enabled NIC1 should succeed. - // TODO(b/147015577): Should we instead completely invalidate all Routes that - // were bound to a disabled NIC at some point? - if err := s.EnableNIC(nicID1); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) - } - testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + // Writes with Routes that use NIC2 after being brought down should fail. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + if upFn := test.upFn; upFn != nil { + // Writes with Routes that use NIC1 after being brought up should + // succeed. + // + // TODO(b/147015577): Should we instead completely invalidate all + // Routes that were bound to a NIC that was brought down at some + // point? + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + } + }) + } + }) } func TestRoutes(t *testing.T) { @@ -1401,19 +1447,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } // If the NIC doesn't exist, it won't work. if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } } @@ -1439,12 +1485,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { } nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } // Set the initial route table. @@ -1459,10 +1505,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // When an interface is given, the route for a broadcast goes through it. r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } // When an interface is not given, it consults the route table. @@ -2213,7 +2259,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2240,56 +2286,84 @@ func TestNICStats(t *testing.T) { } func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) + const nicID1 = 1 + const nicID2 = 2 + const dstAddr = tcpip.Address("\x03") - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + tests := []struct { + name string + headerLen uint16 + }{ + { + name: "Zero header length", + }, + { + name: "Non-zero header length", + headerLen: 16, + }, } - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + s.SetForwarding(true) - // Route all packets to address 3 to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) - } + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) + } - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + ep2 := channelLinkWithHeaderLength{ + Endpoint: channel.New(10, defaultMTU, ""), + headerLength: test.headerLen, + } + if err := s.CreateNIC(nicID2, &ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) + } - if _, ok := ep2.Read(); !ok { - t.Fatal("Packet not forwarded") - } + // Route all packets to dstAddr to NIC 2. + { + subnet, err := tcpip.NewSubnet(dstAddr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) + } - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } + // Send a packet to dstAddr. + buf := buffer.NewView(30) + buf[0] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not forwarded") + } + + // Test that the link's MaxHeaderLength is honoured. + if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { + t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + } + + // Test that forwarding increments Tx stats correctly. + if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + } + + if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + }) } } @@ -2327,7 +2401,7 @@ func TestNICContextPreservation(t *testing.T) { t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) } if got, want := nicinfo.Context == test.want, true; got != want { - t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) } }) } @@ -2696,7 +2770,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { - t.Fatalf("NewSubnet failed:", err) + t.Fatalf("NewSubnet failed: %v", err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } @@ -2710,11 +2784,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // permanentExpired kind. r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) if err != nil { - t.Fatal("FindRoute failed:", err) + t.Fatalf("FindRoute failed: %v", err) } defer r.Release() if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } // @@ -2726,7 +2800,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } @@ -2734,7 +2808,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // make sure the new peb was respected. // (The address should just be promoted now). if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[1].ProtocolAddresses { @@ -2767,11 +2841,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // GetMainNICAddress; else, our original address // should be returned. if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } addr, err = s.GetMainNICAddress(1, fakeNetNumber) if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + t.Fatalf("s.GetMainNICAddress failed: %v", err) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { @@ -3010,6 +3084,50 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { } } +// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 +// address after leaving its solicited node multicast address does not result in +// an error. +func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + } + + // The NIC should have joined addr1's solicited node multicast address. + snmc := header.SolicitedNodeAddr(addr1) + in, err := s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if !in { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) + } + + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) + } + in, err = s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if in { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) + } + + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + } +} + func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { const nicID = 1 @@ -3060,8 +3178,6 @@ func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { // TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC // was disabled have DAD performed on them when the NIC is enabled. func TestDoDADWhenNICEnabled(t *testing.T) { - t.Parallel() - const dadTransmits = 1 const retransmitTimer = time.Second const nicID = 1 diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 778c0a4d6..9a33ed375 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,9 +15,9 @@ package stack import ( + "container/heap" "fmt" "math/rand" - "sort" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -35,7 +35,7 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]*endpointsByNic + endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint @@ -46,11 +46,11 @@ type transportEndpoints struct { func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] + epsByNIC, ok := eps.endpoints[id] if !ok { return } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + if !epsByNIC.unregisterEndpoint(bindToDevice, ep) { return } delete(eps.endpoints, id) @@ -66,18 +66,85 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { return es } -type endpointsByNic struct { +// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in +// descending order of match quality. If a call to yield returns false, +// iterEndpointsLocked stops iteration and returns immediately. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { + // Try to find a match with the id as provided. + if ep, ok := eps.endpoints[id]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the local address. + nid := id + + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the remote part. + nid.LocalAddress = id.LocalAddress + nid.RemoteAddress = "" + nid.RemotePort = 0 + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with only the local port. + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } +} + +// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in +// descending order of match quality. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { + var matchedEPs []*endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEPs = append(matchedEPs, ep) + return true + }) + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given id. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { + var matchedEP *endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEP = ep + return false + }) + return matchedEP +} + +type endpointsByNIC struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 } -func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() var eps []TransportEndpoint - for _, ep := range epsByNic.endpoints { + for _, ep := range epsByNIC.endpoints { eps = append(eps, ep.transportEndpoints()...) } return eps @@ -85,13 +152,13 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { - epsByNic.mu.RLock() +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { + epsByNIC.mu.RLock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } } @@ -100,29 +167,29 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { mpep.handlePacketAll(r, id, pkt) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNic.seed) + transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(r, transEP, id, pkt) - epsByNic.mu.RUnlock() + epsByNIC.mu.RUnlock() return } transEP.HandlePacket(r, id, pkt) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() - mpep, ok := epsByNic.endpoints[n.ID()] + mpep, ok := epsByNIC.endpoints[n.ID()] if !ok { - mpep, ok = epsByNic.endpoints[0] + mpep, ok = epsByNIC.endpoints[0] } if !ok { return @@ -132,40 +199,41 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() - if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { - // There was already a bind. - return multiPortEp.singleRegisterEndpoint(t, reusePort) + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] + if !ok { + multiPortEp = &multiPortEndpoint{ + demux: d, + netProto: netProto, + transProto: transProto, + reuse: reusePort, + } + epsByNIC.endpoints[bindToDevice] = multiPortEp } - // This is a new binding. - multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.reuse = reusePort - epsByNic.endpoints[bindToDevice] = multiPortEp return multiPortEp.singleRegisterEndpoint(t, reusePort) } -// unregisterEndpoint returns true if endpointsByNic has to be unregistered. -func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() - multiPortEp, ok := epsByNic.endpoints[bindToDevice] +// unregisterEndpoint returns true if endpointsByNIC has to be unregistered. +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { return false } if multiPortEp.unregisterEndpoint(t) { - delete(epsByNic.endpoints, bindToDevice) + delete(epsByNIC.endpoints, bindToDevice) } - return len(epsByNic.endpoints) == 0 + return len(epsByNIC.endpoints) == 0 } // transportDemuxer demultiplexes packets targeted at a transport endpoint @@ -183,7 +251,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -197,7 +265,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for proto := range stack.transportProtocols { protoIDs := protocolIDs{netProto, proto} d.protocol[protoIDs] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]*endpointsByNic), + endpoints: make(map[TransportEndpointID]*endpointsByNIC), } qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) if isQueued { @@ -222,6 +290,35 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } +type transportEndpointHeap []TransportEndpoint + +var _ heap.Interface = (*transportEndpointHeap)(nil) + +func (h *transportEndpointHeap) Len() int { + return len(*h) +} + +func (h *transportEndpointHeap) Less(i, j int) bool { + return (*h)[i].UniqueID() < (*h)[j].UniqueID() +} + +func (h *transportEndpointHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] +} + +func (h *transportEndpointHeap) Push(x interface{}) { + *h = append(*h, x.(TransportEndpoint)) +} + +func (h *transportEndpointHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + old[n-1] = nil + *h = old[:n-1] + return x +} + // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -237,15 +334,14 @@ type multiPortEndpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - endpointsArr []TransportEndpoint - endpointsMap map[TransportEndpoint]int + endpoints transportEndpointHeap // reuse indicates if more than one endpoint is allowed. reuse bool } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + eps := append([]TransportEndpoint(nil), ep.endpoints...) ep.mu.RUnlock() return eps } @@ -262,8 +358,8 @@ func reciprocalScale(val, n uint32) uint32 { // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpointsArr) == 1 { - return mpep.endpointsArr[0] + if len(mpep.endpoints) == 1 { + return mpep.endpoints[0] } payload := []byte{ @@ -279,29 +375,26 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) - return mpep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) + return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] - for i, endpoint := range ep.endpointsArr { - // HandlePacket takes ownership of pkt, so each endpoint needs - // its own copy except for the final one. - if i == len(ep.endpointsArr)-1 { - if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) - break - } - endpoint.HandlePacket(r, id, pkt) - break - } + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. + for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) - continue + } else { + endpoint.HandlePacket(r, id, pkt.Clone()) } - endpoint.HandlePacket(r, id, pkt.Clone()) + } + if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + } else { + endpoint.HandlePacket(r, id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -312,26 +405,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo ep.mu.Lock() defer ep.mu.Unlock() - if len(ep.endpointsArr) > 0 { + if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if !ep.reuse || !reusePort { return tcpip.ErrPortInUse } } - // A new endpoint is added into endpointsArr and its index there is saved in - // endpointsMap. This will allow us to remove endpoint from the array fast. - ep.endpointsMap[t] = len(ep.endpointsArr) - ep.endpointsArr = append(ep.endpointsArr, t) + heap.Push(&ep.endpoints, t) - // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints - // can be restored in the same order. - sort.Slice(ep.endpointsArr, func(i, j int) bool { - return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() - }) - for i, e := range ep.endpointsArr { - ep.endpointsMap[e] = i - } return nil } @@ -340,21 +422,13 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { ep.mu.Lock() defer ep.mu.Unlock() - idx, ok := ep.endpointsMap[t] - if !ok { - return false - } - delete(ep.endpointsMap, t) - l := len(ep.endpointsArr) - if l > 1 { - // The last endpoint in endpointsArr is moved instead of the deleted one. - lastEp := ep.endpointsArr[l-1] - ep.endpointsArr[idx] = lastEp - ep.endpointsMap[lastEp] = idx - ep.endpointsArr = ep.endpointsArr[0 : l-1] - return false + for i, endpoint := range ep.endpoints { + if endpoint == t { + heap.Remove(&ep.endpoints, i) + break + } } - return true + return len(ep.endpoints) == 0 } func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { @@ -371,19 +445,16 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - if epsByNic, ok := eps.endpoints[id]; ok { - // There was already a binding. - return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) - } - - // This is a new binding. - epsByNic := &endpointsByNic{ - endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + epsByNIC, ok := eps.endpoints[id] + if !ok { + epsByNIC = &endpointsByNIC{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), + } + eps.endpoints[id] = epsByNIC } - eps.endpoints[id] = epsByNic - return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -396,84 +467,60 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN } } -var loopbackSubnet = func() tcpip.Subnet { - sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") - if err != nil { - panic(err) - } - return sn -}() - // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false } - eps.mu.RLock() - - // Determine which transport endpoint or endpoints to deliver this packet to. // If the packet is a UDP broadcast or multicast, then find all matching - // transport endpoints. If the packet is a TCP packet with a non-unicast - // source or destination address, then do nothing further and instruct - // the caller to do the same. - var destEps []*endpointsByNic - switch protocol { - case header.UDPProtocolNumber: - if isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, id) - break - } - - if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) + // transport endpoints. + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + eps.mu.RLock() + destEPs := eps.findAllEndpointsLocked(id) + eps.mu.RUnlock() + // Fail if we didn't find at least one matching transport endpoint. + if len(destEPs) == 0 { + r.Stats().UDP.UnknownPortErrors.Increment() + return false } - - case header.TCPProtocolNumber: - if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { - // TCP can only be used to communicate between a single - // source and a single destination; the addresses must - // be unicast. - eps.mu.RUnlock() - r.Stats().TCP.InvalidSegmentsReceived.Increment() - return true + // handlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEPs[:len(destEPs)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + return true + } - fallthrough - - default: - if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) - } + // If the packet is a TCP packet with a non-unicast source or destination + // address, then do nothing further and instruct the caller to do the same. + if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single source and a + // single destination; the addresses must be unicast. + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true } + eps.mu.RLock() + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() - - // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 { - // UDP packet could not be delivered to an unknown destination port. + if ep == nil { if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() } return false } - - // HandlePacket takes ownership of pkt, so each endpoint needs its own - // copy except for the final one. - for _, ep := range destEps[:len(destEps)-1] { - ep.handlePacket(r, id, pkt.Clone()) - } - destEps[len(destEps)-1].handlePacket(r, id, pkt) - + ep.handlePacket(r, id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -497,99 +544,53 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false } - // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, id) + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() - - // Fail if we didn't find one. if ep == nil { return false } - // Deliver the packet. ep.handleControlPacket(n, id, typ, extra, pkt) - return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { - var matchedEPs []*endpointsByNic - // Try to find a match with the id as provided. - if ep, ok := eps.endpoints[id]; ok { - matchedEPs = append(matchedEPs, ep) - } - - // Try to find a match with the id minus the local address. - nid := id - - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) - } - - // Try to find a match with the id minus the remote part. - nid.LocalAddress = id.LocalAddress - nid.RemoteAddress = "" - nid.RemotePort = 0 - if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) - } - - // Try to find a match with only the local port. - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) - } - return matchedEPs -} - // findTransportEndpoint find a single endpoint that most closely matches the provided id. func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil } - // Try to find the endpoint. + eps.mu.RLock() - epsByNic := d.findEndpointLocked(eps, id) - // Fail if we didn't find one. - if epsByNic == nil { + epsByNIC := eps.findEndpointLocked(id) + if epsByNIC == nil { eps.mu.RUnlock() return nil } - epsByNic.mu.RLock() + epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return nil } } - ep := selectEndpoint(id, mpep, epsByNic.seed) - epsByNic.mu.RUnlock() + ep := selectEndpoint(id, mpep, epsByNIC.seed) + epsByNIC.mu.RUnlock() return ep } -// findEndpointLocked returns the endpoint that most closely matches the given -// id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { - return matchedEPs[0] - } - return nil -} - // registerRawEndpoint registers the given endpoint with the dispatcher such // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw @@ -601,8 +602,8 @@ func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNum } eps.mu.Lock() - defer eps.mu.Unlock() eps.rawEndpoints = append(eps.rawEndpoints, ep) + eps.mu.Unlock() return nil } @@ -616,13 +617,16 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } eps.mu.Lock() - defer eps.mu.Unlock() for i, rawEP := range eps.rawEndpoints { if rawEP == ep { - eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) - return + lastIdx := len(eps.rawEndpoints) - 1 + eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx] + eps.rawEndpoints[lastIdx] = nil + eps.rawEndpoints = eps.rawEndpoints[:lastIdx] + break } } + eps.mu.Unlock() } func isMulticastOrBroadcast(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 5e9237de9..2474a7db3 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -31,84 +31,58 @@ import ( ) const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testPort = 4096 + testSrcAddrV4 = "\x0a\x00\x00\x01" + testDstAddrV4 = "\x0a\x00\x00\x02" + + testDstPort = 1234 + testSrcPort = 4096 ) type testContext struct { - t *testing.T linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createV6Endpoint(v6only bool) { - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } + wq waiter.Queue } // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { channelEp := channel.New(256, mtu, "") if err := s.CreateNIC(linkEpID, channelEp); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress IPv4 failed: %v", err) + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { + t.Fatalf("AddAddress IPv4 failed: %s", err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress IPv6 failed: %v", err) + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { + t.Fatalf("AddAddress IPv6 failed: %s", err) } } s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, + {Destination: header.IPv4EmptySubnet, NIC: 1}, + {Destination: header.IPv6EmptySubnet, NIC: 1}, }) return &testContext{ - t: t, s: s, linkEps: linkEps, } } type headers struct { - srcPort uint16 - dstPort uint16 + srcPort, dstPort uint16 } func newPayload() []byte { @@ -119,6 +93,47 @@ func newPayload() []byte { return b } +func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { + buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: 0x80, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: testSrcAddrV4, + DstAddr: testDstAddrV4, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) +} + func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) @@ -130,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -143,15 +158,17 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), }) } @@ -167,38 +184,48 @@ func TestTransportDemuxerRegister(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) } } -// TestReuseBindToDevice injects varied packets on input devices and checks that +// TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. -func TestDistribution(t *testing.T) { +func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { - reuse int + reuse bool bindToDevice tcpip.NICID } for _, test := range []struct { name string // endpoints will received the inject packets. endpoints []endpointSockopts - // wantedDistribution is the wanted ratio of packets received on each + // wantDistributions is the want ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[tcpip.NICID][]float64 + wantDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. @@ -209,9 +236,9 @@ func TestDistribution(t *testing.T) { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {0, 1}, - {0, 2}, - {0, 3}, + {reuse: false, bindToDevice: 1}, + {reuse: false, bindToDevice: 2}, + {reuse: false, bindToDevice: 3}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. @@ -226,12 +253,12 @@ func TestDistribution(t *testing.T) { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {1, 1}, - {1, 1}, - {1, 2}, - {1, 2}, - {1, 2}, - {1, 0}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to @@ -245,17 +272,17 @@ func TestDistribution(t *testing.T) { }, }, } { - t.Run(test.name, func(t *testing.T) { - for device, wantedDistribution := range test.wantedDistributions { - t.Run(string(device), func(t *testing.T) { + for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } { + for device, wantDistribution := range test.wantDistributions { + t.Run(test.name+protoName+string(device), func(t *testing.T) { var devices []tcpip.NICID - for d := range test.wantedDistributions { + for d := range test.wantDistributions { devices = append(devices, d) } c := newDualTestContextMultiNIC(t, defaultMTU, devices) - defer c.cleanup() - - c.createV6Endpoint(false) eps := make(map[tcpip.Endpoint]int) @@ -269,9 +296,9 @@ func TestDistribution(t *testing.T) { defer close(ch) var err *tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } eps[ep] = i @@ -282,22 +309,31 @@ func TestDistribution(t *testing.T) { }(ep) defer ep.Close() - reusePortOption := tcpip.ReusePortOption(endpoint.reuse) - if err := ep.SetSockOpt(reusePortOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { + t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) + } + + var dstAddr tcpip.Address + switch netProtoNum { + case ipv4.ProtocolNumber: + dstAddr = testDstAddrV4 + case ipv6.ProtocolNumber: + dstAddr = testDstAddrV6 + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) } - if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } npackets := 100000 nports := 10000 - if got, want := len(test.endpoints), len(wantedDistribution); got != want { + if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } ports := make(map[uint16]tcpip.Endpoint) @@ -306,17 +342,22 @@ func TestDistribution(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, - &headers{ - srcPort: testPort + port, - dstPort: stackPort}, - device) + hdrs := &headers{ + srcPort: testSrcPort + port, + dstPort: testDstPort, + } + switch netProtoNum { + case ipv4.ProtocolNumber: + c.sendV4Packet(payload, hdrs, device) + case ipv6.ProtocolNumber: + c.sendV6Packet(payload, hdrs, device) + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) + } - var addr tcpip.FullAddress ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + if _, _, err := ep.Read(nil); err != nil { + t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ if i < nports { @@ -332,17 +373,17 @@ func TestDistribution(t *testing.T) { // Check that a packet distribution is as expected. for ep, i := range eps { - wantedRatio := wantedDistribution[i] - wantedRecv := wantedRatio * float64(npackets) + wantRatio := wantDistribution[i] + wantRecv := wantRatio * float64(npackets) actualRecv := stats[ep] actualRatio := float64(stats[ep]) / float64(npackets) // The deviation is less than 10%. - if math.Abs(actualRatio-wantedRatio) > 0.05 { - t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + if math.Abs(actualRatio-wantRatio) > 0.05 { + t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) } } }) } - }) + } } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5d1da2f8b..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -57,6 +56,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } +func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} + func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } @@ -87,7 +88,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -214,7 +215,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -231,7 +232,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -242,8 +243,8 @@ func (f *fakeTransportEndpoint) State() uint32 { func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} -func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { - return iptables.IPTables{}, nil +func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) { + return stack.IPTables{}, nil } func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} @@ -288,7 +289,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { return true } @@ -368,7 +369,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -379,7 +380,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -390,7 +391,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -445,7 +446,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -456,7 +457,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -467,7 +468,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -622,7 +623,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: req.ToVectorisedView(), }) @@ -641,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } |