diff options
Diffstat (limited to 'pkg/tcpip/stack/stack_test.go')
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 1538 |
1 files changed, 1333 insertions, 205 deletions
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9a8906a0d..d45d2cc1f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,12 +27,16 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( @@ -58,7 +62,7 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int proto *fakeNetworkProtocol @@ -71,7 +75,7 @@ func (f *fakeNetworkEndpoint) MTU() uint32 { } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicid + return f.nicID } func (f *fakeNetworkEndpoint) PrefixLen() int { @@ -86,28 +90,30 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } + pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + pkt.Data.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -122,37 +128,38 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := hdr.Prepend(fakeNetHeaderLen) + b := pkt.Header.Prepend(fakeNetHeaderLen) b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) - if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) + if r.Loop&stack.PacketLoop != 0 { + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + f.HandlePacket(r, stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } - return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -197,9 +204,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, @@ -230,10 +237,33 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +// Close implements TransportProtocol.Close. +func (*fakeNetworkProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeNetworkProtocol) Wait() {} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } +// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify +// that LinkEndpoint.Attach was called. +type linkEPWithMockedAttach struct { + stack.LinkEndpoint + attached bool +} + +// Attach implements stack.LinkEndpoint.Attach. +func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { + l.LinkEndpoint.Attach(d) + l.attached = d != nil +} + +func (l *linkEPWithMockedAttach) isAttached() bool { + return l.attached +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -259,7 +289,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -269,7 +301,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -279,7 +313,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -288,7 +324,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -298,7 +336,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -318,7 +358,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -373,7 +416,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -490,6 +535,340 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr } } +// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to +// a NetworkDispatcher when the NIC is created. +func TestAttachToLinkEndpointImmediately(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + nicOpts stack.NICOptions + }{ + { + name: "Create enabled NIC", + nicOpts: stack.NICOptions{Disabled: false}, + }, + { + name: "Create disabled NIC", + nicOpts: stack.NICOptions{Disabled: true}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + + if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") + } + }) + } +} + +func TestDisableUnknownNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + } +} + +func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := loopback.New() + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + checkNIC := func(enabled bool) { + t.Helper() + + allNICInfo := s.NICInfo() + nicInfo, ok := allNICInfo[nicID] + if !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } else if nicInfo.Flags.Running != enabled { + t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) + } + + if got := s.CheckNIC(nicID); got != enabled { + t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) + } + } + + // NIC should initially report itself as disabled. + checkNIC(false) + + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + checkNIC(true) + + // If the NIC is not reporting a correct enabled status, we cannot trust the + // next check so end the test here. + if t.Failed() { + t.FailNow() + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + checkNIC(false) +} + +func TestRemoveUnknownNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + } +} + +func TestRemoveNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") + } + + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) + } + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) + } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } +} + +func TestRouteWithDownNIC(t *testing.T) { + tests := []struct { + name string + downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + }{ + { + name: "Disabled NIC", + downFn: (*stack.Stack).DisableNIC, + upFn: (*stack.Stack).EnableNIC, + }, + + // Once a NIC is removed, it cannot be brought up. + { + name: "Removed NIC", + downFn: (*stack.Stack).RemoveNIC, + }, + } + + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + const addr1 = tcpip.Address("\x01") + const addr2 = tcpip.Address("\x02") + const nic1Dst = tcpip.Address("\x05") + const nic2Dst = tcpip.Address("\x06") + + setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + return s, ep1, ep2 + } + + // Tests that routes through a down NIC are not used when looking up a route + // for a destination. + t.Run("Find", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, _, _ := setup(t) + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Bringing NIC1 down should result in no routes to odd addresses. Routes to + // even addresses should continue to be available as NIC2 is still up. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Bringing NIC2 down should result in no routes to even addresses. No + // route should be available to any address as routes to odd addresses + // were made unavailable by bringing NIC1 down above. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + if upFn := test.upFn; upFn != nil { + // Bringing NIC1 up should make routes to odd addresses available + // again. Routes to even addresses should continue to be unavailable + // as NIC2 is still down. + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + } + }) + } + }) + + // Tests that writing a packet using a Route through a down NIC fails. + t.Run("WritePacket", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, ep1, ep2 := setup(t) + + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() + + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() + + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } + + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use NIC1 after being brought down should fail. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use NIC2 after being brought down should fail. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + if upFn := test.upFn; upFn != nil { + // Writes with Routes that use NIC1 after being brought up should + // succeed. + // + // TODO(b/147015577): Should we instead completely invalidate all + // Routes that were bound to a NIC that was brought down at some + // point? + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + } + }) + } + }) +} + func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has @@ -668,11 +1047,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } } -func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { +func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { t.Helper() - info, ok := s.NICInfo()[nicid] + info, ok := s.NICInfo()[nicID] if !ok { - t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + t.Fatalf("NICInfo() failed to find nicID=%d", nicID) } if len(addr) == 0 { // No address given, verify that there is no address assigned to the NIC. @@ -705,7 +1084,7 @@ func TestEndpointExpiration(t *testing.T) { localAddrByte byte = 0x01 remoteAddr tcpip.Address = "\x03" noAddr tcpip.Address = "" - nicid tcpip.NICID = 1 + nicID tcpip.NICID = 1 ) localAddr := tcpip.Address([]byte{localAddrByte}) @@ -717,7 +1096,7 @@ func TestEndpointExpiration(t *testing.T) { }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -734,13 +1113,13 @@ func TestEndpointExpiration(t *testing.T) { buf[0] = localAddrByte if promiscuous { - if err := s.SetPromiscuousMode(nicid, true); err != nil { + if err := s.SetPromiscuousMode(nicID, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } } if spoofing { - if err := s.SetSpoofing(nicid, true); err != nil { + if err := s.SetSpoofing(nicID, true); err != nil { t.Fatal("SetSpoofing failed:", err) } } @@ -748,7 +1127,7 @@ func TestEndpointExpiration(t *testing.T) { // 1. No Address yet, send should only work for spoofing, receive for // promiscuous mode. //----------------------- - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -763,20 +1142,20 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -791,10 +1170,10 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -812,10 +1191,10 @@ func TestEndpointExpiration(t *testing.T) { // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -831,10 +1210,10 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) testSend(t, r, ep, nil) @@ -842,17 +1221,17 @@ func TestEndpointExpiration(t *testing.T) { // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -1068,19 +1447,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } // If the NIC doesn't exist, it won't work. if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } } @@ -1106,12 +1485,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { } nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } // Set the initial route table. @@ -1126,10 +1505,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // When an interface is given, the route for a broadcast goes through it. r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } // When an interface is not given, it consults the route table. @@ -1645,12 +2024,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } func TestAddAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1658,7 +2037,7 @@ func TestAddAddress(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) for _, addrLen := range []int{4, 16} { address := addrGen.next(addrLen) - if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { t.Fatalf("AddAddress(address=%s) failed: %s", address, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1667,17 +2046,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1694,24 +2073,24 @@ func TestAddProtocolAddress(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) } expectedAddresses = append(expectedAddresses, protocolAddress) } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1722,7 +2101,7 @@ func TestAddAddressWithOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1732,17 +2111,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1761,7 +2140,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) } expectedAddresses = append(expectedAddresses, protocolAddress) @@ -1769,10 +2148,95 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } +func TestCreateNICWithOptions(t *testing.T) { + type callArgsAndExpect struct { + nicID tcpip.NICID + opts stack.NICOptions + err *tcpip.Error + } + + tests := []struct { + desc string + calls []callArgsAndExpect + }{ + { + desc: "DuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth1"}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth2"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "DuplicateName", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "lo"}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{Name: "lo"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "Unnamed", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{}, + err: nil, + }, + }, + }, + { + desc: "UnnamedDuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + s := stack.New(stack.Options{}) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + for _, call := range test.calls { + if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { + t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) + } + } + }) + } +} + func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, @@ -1795,7 +2259,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1820,150 +2286,386 @@ func TestNICStats(t *testing.T) { } func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) + const nicID1 = 1 + const nicID2 = 2 + const dstAddr = tcpip.Address("\x03") - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + tests := []struct { + name string + headerLen uint16 + }{ + { + name: "Zero header length", + }, + { + name: "Non-zero header length", + headerLen: 16, + }, } - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + s.SetForwarding(true) + + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) + } + + ep2 := channelLinkWithHeaderLength{ + Endpoint: channel.New(10, defaultMTU, ""), + headerLength: test.headerLen, + } + if err := s.CreateNIC(nicID2, &ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) + } + + // Route all packets to dstAddr to NIC 2. + { + subnet, err := tcpip.NewSubnet(dstAddr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) + } + + // Send a packet to dstAddr. + buf := buffer.NewView(30) + buf[0] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not forwarded") + } + + // Test that the link's MaxHeaderLength is honoured. + if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { + t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + } + + // Test that forwarding increments Tx stats correctly. + if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + } + + if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + }) } +} - // Route all packets to address 3 to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) +// TestNICContextPreservation tests that you can read out via stack.NICInfo the +// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. +func TestNICContextPreservation(t *testing.T) { + var ctx *int + tests := []struct { + name string + opts stack.NICOptions + want stack.NICContext + }{ + { + "context_set", + stack.NICOptions{Context: ctx}, + ctx, + }, + { + "context_not_set", + stack.NICOptions{}, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + id := tcpip.NICID(1) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { + t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) + } + nicinfos := s.NICInfo() + nicinfo, ok := nicinfos[id] + if !ok { + t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) + } + if got, want := nicinfo.Context == test.want, true; got != want { + t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + } + }) } +} - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) +// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local +// addresses. +func TestNICAutoGenLinkLocalAddr(t *testing.T) { + const nicID = 1 - select { - case <-ep2.C: - default: - t.Fatal("Packet not forwarded") + var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte + n, err := rand.Read(secretKey[:]) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) } - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + nicNameFunc := func(_ tcpip.NICID, name string) string { + return name } -} -// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// (or lack there-of if disabled (default)). Note, DAD will be disabled in -// these tests. -func TestNICAutoGenAddr(t *testing.T) { tests := []struct { - name string - autoGen bool - linkAddr tcpip.LinkAddress - shouldGen bool + name string + nicName string + autoGen bool + linkAddr tcpip.LinkAddress + iidOpts stack.OpaqueInterfaceIdentifierOptions + shouldGen bool + expectedAddr tcpip.Address }{ { - "Disabled", - false, - linkAddr1, - false, + name: "Disabled", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + shouldGen: false, + }, + { + name: "Disabled without OIID options", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:], + }, + shouldGen: false, + }, + + // Tests for EUI64 based addresses. + { + name: "EUI64 Enabled", + autoGen: true, + linkAddr: linkAddr1, + shouldGen: true, + expectedAddr: header.LinkLocalAddr(linkAddr1), + }, + { + name: "EUI64 Empty MAC", + autoGen: true, + shouldGen: false, + }, + { + name: "EUI64 Invalid MAC", + autoGen: true, + linkAddr: "\x01\x02\x03", + shouldGen: false, + }, + { + name: "EUI64 Multicast MAC", + autoGen: true, + linkAddr: "\x01\x02\x03\x04\x05\x06", + shouldGen: false, }, { - "Enabled", - true, - linkAddr1, - true, + name: "EUI64 Unspecified MAC", + autoGen: true, + linkAddr: "\x00\x00\x00\x00\x00\x00", + shouldGen: false, }, + + // Tests for Opaque IID based addresses. { - "Nil MAC", - true, - tcpip.LinkAddress([]byte(nil)), - false, + name: "OIID Enabled", + nicName: "nic1", + autoGen: true, + linkAddr: linkAddr1, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), }, + // These are all cases where we would not have generated a + // link-local address if opaque IIDs were disabled. { - "Empty MAC", - true, - tcpip.LinkAddress(""), - false, + name: "OIID Empty MAC and empty nicName", + autoGen: true, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:1], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]), }, { - "Invalid MAC", - true, - tcpip.LinkAddress("\x01\x02\x03"), - false, + name: "OIID Invalid MAC", + nicName: "test", + autoGen: true, + linkAddr: "\x01\x02\x03", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:2], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]), }, { - "Multicast MAC", - true, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - false, + name: "OIID Multicast MAC", + nicName: "test2", + autoGen: true, + linkAddr: "\x01\x02\x03\x04\x05\x06", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:3], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]), }, { - "Unspecified MAC", - true, - tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), - false, + name: "OIID Unspecified MAC and nil SecretKey", + nicName: "test3", + autoGen: true, + linkAddr: "\x00\x00\x00\x00\x00\x00", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, } - if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when - // test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by - // default. - opts.AutoGenIPv6LinkLocal = true + e := channel.New(0, 1280, test.linkAddr) + s := stack.New(opts) + nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - e := channel.New(10, 1280, test.linkAddr) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + // A new disabled NIC should not have any address, even if auto generation + // was enabled. + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) } - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + // Enabling the NIC should attempt auto-generation of a link-local + // address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } + var expectedMainAddr tcpip.AddressWithPrefix if test.shouldGen { - // Should have auto-generated an address and - // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + expectedMainAddr = tcpip.AddressWithPrefix{ + Address: test.expectedAddr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + } + + // Should have auto-generated an address and resolved immediately (DAD + // is disabled). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } } else { // Should not have auto-generated an address. - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address") + default: } } + + gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if gotMainAddr != expectedMainAddr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) + } + }) + } +} + +// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are +// not auto-generated for loopback NICs. +func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { + const nicID = 1 + const nicName = "nicName" + + tests := []struct { + name string + opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + }{ + { + name: "IID From MAC", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + }, + { + name: "Opaque IID", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + } + + e := loopback.New() + s := stack.New(opts) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + } }) } } @@ -1971,47 +2673,56 @@ func TestNICAutoGenAddr(t *testing.T) { // TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 // link-local addresses will only be assigned after the DAD process resolves. func TestNICAutoGenAddrDoesDAD(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 1, - }, + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } // Address should not be considered bound to the // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - // Wait for the address to resolve (an extra - // 250ms to make sure the address resolves). - // - // TODO(b/140896005): Use events from the - // netstack to know immediately when DAD - // completes. - time.Sleep(time.Second + 250*time.Millisecond) + linkLocalAddr := header.LinkLocalAddr(linkAddr1) - // Should have auto-generated an address and - // resolved (if DAD). - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + } + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(linkAddr1), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } } @@ -2059,7 +2770,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { - t.Fatalf("NewSubnet failed:", err) + t.Fatalf("NewSubnet failed: %v", err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } @@ -2073,11 +2784,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // permanentExpired kind. r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) if err != nil { - t.Fatal("FindRoute failed:", err) + t.Fatalf("FindRoute failed: %v", err) } defer r.Release() if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } // @@ -2089,7 +2800,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } @@ -2097,7 +2808,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // make sure the new peb was respected. // (The address should just be promoted now). if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[1].ProtocolAddresses { @@ -2130,11 +2841,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // GetMainNICAddress; else, our original address // should be returned. if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } addr, err = s.GetMainNICAddress(1, fakeNetNumber) if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + t.Fatalf("s.GetMainNICAddress failed: %v", err) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { @@ -2150,3 +2861,420 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { } } } + +func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { + const ( + linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + nicID = 1 + ) + + // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. + tests := []struct { + name string + nicAddrs []tcpip.Address + connectAddr tcpip.Address + expectedLocalAddr tcpip.Address + }{ + // Test Rule 1 of RFC 6724 section 5. + { + name: "Same Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Same Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test Rule 2 of RFC 6724 section 5. + { + name: "Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred for link local multicast (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred for link local multicast (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test returning the endpoint that is closest to the front when + // candidate addresses are "equal" from the perspective of RFC 6724 + // section 5. + { + name: "Unique Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Link Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local for Unique Local", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + + for _, a := range test.nicAddrs { + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { + t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + } + } + + if t.Failed() { + t.FailNow() + } + + if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) + } + }) + } +} + +func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + + // Enabling the NIC should add the IPv4 broadcast address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 1 { + t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) + } + want := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } + if allNICAddrs[0] != want { + t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + } + + // Disabling the NIC should remove the IPv4 broadcast address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } +} + +// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 +// address after leaving its solicited node multicast address does not result in +// an error. +func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + } + + // The NIC should have joined addr1's solicited node multicast address. + snmc := header.SolicitedNodeAddr(addr1) + in, err := s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if !in { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) + } + + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) + } + in, err = s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if in { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) + } + + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + } +} + +func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + // Should not be in the IPv6 all-nodes multicast group yet because the NIC has + // not been enabled yet. + isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } +} + +// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC +// was disabled have DAD performed on them when the NIC is enabled. +func TestDoDADWhenNICEnabled(t *testing.T) { + const dadTransmits = 1 + const retransmitTimer = time.Second + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPDisp: &ndpDisp, + } + + e := channel.New(dadTransmits, 1280, linkAddr1) + s := stack.New(opts) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: llAddr1, + PrefixLen: 128, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + // Address should be in the list of all addresses. + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + + // Address should be tentative so it should not be a main address. + got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); got != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + } + + // Enabling the NIC should start DAD for the address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); got != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if got != addr.AddressWithPrefix { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + } + + // Enabling the NIC again should be a no-op. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if got != addr.AddressWithPrefix { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + } +} |