diff options
Diffstat (limited to 'pkg/tcpip/stack/stack_test.go')
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 376 |
1 files changed, 255 insertions, 121 deletions
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 959071dbe..9d082bba4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "sort" "strings" "testing" @@ -32,8 +33,9 @@ import ( ) const ( - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fakeNetHeaderLen = 12 + fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fakeNetHeaderLen = 12 + fakeDefaultPrefixLen = 8 // fakeControlProtocol is used for control packets that represent // destination port unreachable. @@ -55,6 +57,7 @@ const ( type fakeNetworkEndpoint struct { nicid tcpip.NICID id stack.NetworkEndpointID + prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher linkEP stack.LinkEndpoint @@ -68,6 +71,10 @@ func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { return f.nicid } +func (f *fakeNetworkEndpoint) PrefixLen() int { + return f.prefixLen +} + func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } @@ -170,14 +177,19 @@ func (f *fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } +func (f *fakeNetworkProtocol) DefaultPrefixLen() int { + return fakeDefaultPrefixLen +} + func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, - id: stack.NetworkEndpointID{addr}, + id: stack.NetworkEndpointID{addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, linkEP: linkEP, @@ -212,15 +224,15 @@ func TestNetworkReceive(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -280,13 +292,13 @@ func TestNetworkReceive(t *testing.T) { func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatal("FindRoute failed:", err) } defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Errorf("WritePacket failed: %v", err) + t.Error("WritePacket failed:", err) } } @@ -297,13 +309,13 @@ func TestNetworkSend(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("NewNIC failed: %v", err) + t.Fatal("NewNIC failed:", err) } s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } // Make sure that the link-layer endpoint received the outbound packet. @@ -321,28 +333,28 @@ func TestNetworkSendMultiRoute(t *testing.T) { id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } id2, linkEP2 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } // Set a route table that sends all packets with odd destination @@ -371,7 +383,7 @@ func TestNetworkSendMultiRoute(t *testing.T) { func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatal("FindRoute failed:", err) } defer r.Release() @@ -388,7 +400,7 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err) + t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute) } } @@ -400,28 +412,28 @@ func TestRoutes(t *testing.T) { id1, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } id2, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } // Set a route table that sends all packets with odd destination @@ -464,11 +476,11 @@ func TestAddressRemoval(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -486,7 +498,7 @@ func TestAddressRemoval(t *testing.T) { // Remove the address, then check that packet doesn't get delivered // anymore. if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + t.Fatal("RemoveAddress failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -496,7 +508,7 @@ func TestAddressRemoval(t *testing.T) { // Check that removing the same address fails. if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress failed: %v", err) + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } } @@ -505,11 +517,11 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -531,7 +543,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { // Get a route, check that packet is still deliverable. r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatal("FindRoute failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -542,7 +554,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { // Remove the address, then check that packet is still deliverable // because the route is keeping the address alive. if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + t.Fatal("RemoveAddress failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -552,7 +564,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { // Check that removing the same address fails. if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress failed: %v", err) + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } // Release the route, then check that packet is not deliverable anymore. @@ -568,7 +580,7 @@ func TestPromiscuousMode(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -590,7 +602,7 @@ func TestPromiscuousMode(t *testing.T) { // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { - t.Fatalf("SetPromiscuousMode failed: %v", err) + t.Fatal("SetPromiscuousMode failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -601,13 +613,13 @@ func TestPromiscuousMode(t *testing.T) { // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err) + t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute) } // Set promiscuous mode to false, then check that packet can't be // delivered anymore. if err := s.SetPromiscuousMode(1, false); err != nil { - t.Fatalf("SetPromiscuousMode failed: %v", err) + t.Fatal("SetPromiscuousMode failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -624,11 +636,11 @@ func TestAddressSpoofing(t *testing.T) { id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -645,11 +657,11 @@ func TestAddressSpoofing(t *testing.T) { // With address spoofing enabled, FindRoute permits any address to be used // as the source. if err := s.SetSpoofing(1, true); err != nil { - t.Fatalf("SetSpoofing failed: %v", err) + t.Fatal("SetSpoofing failed:", err) } r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatal("FindRoute failed:", err) } if r.LocalAddress != srcAddr { t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) @@ -664,17 +676,17 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, header.IPv4Any, err) + t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { @@ -735,7 +747,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) @@ -791,7 +803,7 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -806,10 +818,10 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { fakeNet.packetCount[1] = 0 subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + t.Fatal("NewSubnet failed:", err) } if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + t.Fatal("AddSubnet failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -824,7 +836,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -839,10 +851,10 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { fakeNet.packetCount[1] = 0 subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + t.Fatal("NewSubnet failed:", err) } if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + t.Fatal("AddSubnet failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { @@ -894,38 +906,38 @@ func TestSubnetAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } addr := tcpip.Address("\x01\x01\x01\x01") mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) subnet, err := tcpip.NewSubnet(addr, mask) if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + t.Fatal("NewSubnet failed:", err) } if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) + t.Fatal("ContainsSubnet failed:", err) } else if contained { t.Fatal("got s.ContainsSubnet(...) = true, want = false") } if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + t.Fatal("AddSubnet failed:", err) } if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) + t.Fatal("ContainsSubnet failed:", err) } else if !contained { t.Fatal("got s.ContainsSubnet(...) = false, want = true") } if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatalf("RemoveSubnet failed: %v", err) + t.Fatal("RemoveSubnet failed:", err) } if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) + t.Fatal("ContainsSubnet failed:", err) } else if contained { t.Fatal("got s.ContainsSubnet(...) = true, want = false") } @@ -941,11 +953,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } // Insert <canBe> primary and <never> never-primary addresses. // Each one will add a network endpoint to the NIC. - primaryAddrAdded := make(map[tcpip.Address]tcpip.Subnet) + primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{}) for i := 0; i < canBe+never; i++ { var behavior stack.PrimaryEndpointBehavior if i < canBe { @@ -953,46 +965,39 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } else { behavior = stack.NeverPrimaryEndpoint } - // Add an address and in case of a primary one also add a - // subnet. + // Add an address and in case of a primary one include a + // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) - if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) - } if behavior == stack.CanBePrimaryEndpoint { - mask := tcpip.AddressMask(strings.Repeat("\xff", len(address))) - subnet, err := tcpip.NewSubnet(address, mask) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + addressWithPrefix := tcpip.AddressWithPrefix{address, addrLen * 8} + if err := s.AddAddressWithPrefixAndOptions(1, fakeNetNumber, addressWithPrefix, behavior); err != nil { + t.Fatal("AddAddressWithPrefixAndOptions failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + // Remember the address/prefix. + primaryAddrAdded[addressWithPrefix] = struct{}{} + } else { + if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) } - // Remember the address/subnet. - primaryAddrAdded[address] = subnet } } // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the - // address/subnet matches what we added. + // address/prefixLen matches what we added. if len(primaryAddrAdded) == 0 { // No primary addresses present, expect an error. - if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %v", err, tcpip.ErrNoLinkAddress) + if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress) } } else { // At least one primary address was added, expect a valid - // address and subnet. - gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber) + // address and prefixLen. + gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber) if err != nil { - t.Fatalf("GetMainNICAddress failed: %v", err) - } - expectedSubnet, ok := primaryAddrAdded[gotAddress] - if !ok { - t.Fatalf("GetMainNICAddress: got address = %v, wanted any in {%v}", gotAddress, primaryAddrAdded) + t.Fatal("GetMainNICAddress failed:", err) } - if gotSubnet != expectedSubnet { - t.Fatalf("GetMainNICAddress: got subnet = %v, wanted %v", gotSubnet, expectedSubnet) + if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok { + t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded) } } }) @@ -1007,65 +1012,194 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } for _, tc := range []struct { - name string - address tcpip.Address + name string + address tcpip.Address + prefixLen int }{ - {"IPv4", "\x01\x01\x01\x01"}, - {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, + {"IPv4", "\x01\x01\x01\x01", 24}, + {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, } { t.Run(tc.name, func(t *testing.T) { - address := tc.address - mask := tcpip.AddressMask(strings.Repeat("\xff", len(address))) - subnet, err := tcpip.NewSubnet(address, mask) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + addressWithPrefix := tcpip.AddressWithPrefix{tc.address, tc.prefixLen} + + if err := s.AddAddressWithPrefix(1, fakeNetNumber, addressWithPrefix); err != nil { + t.Fatal("AddAddressWithPrefix failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress failed: %v", err) + // Check that we get the right initial address and prefix length. + if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } else if gotAddressWithPrefix != addressWithPrefix { + t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, addressWithPrefix) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + if err := s.RemoveAddress(1, addressWithPrefix.Address); err != nil { + t.Fatal("RemoveAddress failed:", err) } - // Check that we get the right initial address and subnet. - if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress failed: %v", err) - } else if gotAddress != address { - t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address) - } else if gotSubnet != subnet { - t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet) + // Check that we get an error after removal. + if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress) } + }) + } +} + +// Simple network address generator. Good for 255 addresses. +type addressGenerator struct{ cnt byte } + +func (g *addressGenerator) next(addrLen int) tcpip.Address { + g.cnt++ + return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen)) +} + +func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { + if len(gotAddresses) != len(expectedAddresses) { + t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses)) + } + + sort.Slice(gotAddresses, func(i, j int) bool { + return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address + }) + sort.Slice(expectedAddresses, func(i, j int) bool { + return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address + }) + + for i, gotAddr := range gotAddresses { + expectedAddr := expectedAddresses[i] + if gotAddr != expectedAddr { + t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr) + } + } +} + +func TestAddAddress(t *testing.T) { + const nicid = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatalf("RemoveSubnet failed: %v", err) + var addrGen addressGenerator + expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) + for _, addrLen := range []int{4, 16} { + address := addrGen.next(addrLen) + if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + t.Fatalf("AddAddress(address=%s) failed: %s", address, err) + } + expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + }) + } + + gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + verifyAddresses(t, expectedAddresses, gotAddresses) +} + +func TestAddAddressWithPrefix(t *testing.T) { + const nicid = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + var addrGen addressGenerator + addrLenRange := []int{4, 16} + prefixLenRange := []int{8, 13, 20, 32} + expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) + for _, addrLen := range addrLenRange { + for _, prefixLen := range prefixLenRange { + address := addrGen.next(addrLen) + if err := s.AddAddressWithPrefix(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}); err != nil { + t.Errorf("AddAddressWithPrefix(address=%s, prefixLen=%d) failed: %s", address, prefixLen, err) } + expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen}, + }) + } + } - if err := s.RemoveAddress(1, address); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + verifyAddresses(t, expectedAddresses, gotAddresses) +} + +func TestAddAddressWithOptions(t *testing.T) { + const nicid = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + addrLenRange := []int{4, 16} + behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} + expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) + var addrGen addressGenerator + for _, addrLen := range addrLenRange { + for _, behavior := range behaviorRange { + address := addrGen.next(addrLen) + if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } + expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + }) + } + } - // Check that we get an error after removal. - if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress) + gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + verifyAddresses(t, expectedAddresses, gotAddresses) +} + +func TestAddAddressWithPrefixAndOptions(t *testing.T) { + const nicid = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + addrLenRange := []int{4, 16} + prefixLenRange := []int{8, 13, 20, 32} + behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} + expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) + var addrGen addressGenerator + for _, addrLen := range addrLenRange { + for _, prefixLen := range prefixLenRange { + for _, behavior := range behaviorRange { + address := addrGen.next(addrLen) + if err := s.AddAddressWithPrefixAndOptions(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}, behavior); err != nil { + t.Fatalf("AddAddressWithPrefixAndOptions(address=%s, prefixLen=%d, behavior=%d) failed: %s", address, prefixLen, behavior, err) + } + expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen}, + }) } - }) + } } + + gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + verifyAddresses(t, expectedAddresses, gotAddresses) } func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } // Route all packets for address \x01 to NIC 1. s.SetRouteTable([]tcpip.Route{ @@ -1104,18 +1238,18 @@ func TestNICForwarding(t *testing.T) { id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) + t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) + t.Fatal("AddAddress #1 failed:", err) } id2, linkEP2 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) + t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) + t.Fatal("AddAddress #2 failed:", err) } // Route all packets to address 3 to NIC 2. |