summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go310
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go290
2 files changed, 522 insertions, 78 deletions
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 21a8dd291..b56706357 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type inputIfNameMatcher struct {
@@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) {
})
}
}
+
+var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil)
+
+// channelEndpointWithoutWritePacket is a channel endpoint that does not support
+// stack.LinkEndpoint.WritePacket.
+type channelEndpointWithoutWritePacket struct {
+ *channel.Endpoint
+
+ t *testing.T
+}
+
+func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+ c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
+ return &tcpip.ErrNotSupported{}
+}
+
+var _ stack.Matcher = (*udpSourcePortMatcher)(nil)
+
+type udpSourcePortMatcher struct {
+ port uint16
+}
+
+func (*udpSourcePortMatcher) Name() string {
+ return "udpSourcePortMatcher"
+}
+
+func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
+ udp := header.UDP(pkt.TransportHeader().View())
+ if len(udp) < header.UDPMinimumSize {
+ // Drop immediately as the packet is invalid.
+ return false, true
+ }
+
+ return udp.SourcePort() == m.port, false
+}
+
+func TestIPTableWritePackets(t *testing.T) {
+ const (
+ nicID = 1
+
+ dropLocalPort = localPort - 1
+ acceptPackets = 2
+ dropPackets = 3
+ )
+
+ udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
+ u := header.UDP(hdr)
+ u.Encode(&header.UDPFields{
+ SrcPort: srcPort,
+ DstPort: dstPort,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
+ sum = header.Checksum(hdr, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ }
+
+ tests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack)
+ genPacket func(*stack.Route) stack.PacketBufferList
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectSent uint64
+ expectOutputDropped uint64
+ }{
+ {
+ name: "IPv4 Accept",
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+
+ return pkts
+ },
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv4 Drop Other Port",
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
+ t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err)
+ }
+ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ for i := 0; i < acceptPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ for i := 0; i < dropPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+
+ return pkts
+ },
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ {
+ name: "IPv6 Accept",
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+
+ return pkts
+ },
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop Other Port",
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
+ t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err)
+ }
+ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ for i := 0; i < acceptPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ for i := 0; i < dropPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+
+ return pkts
+ },
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channelEndpointWithoutWritePacket{
+ Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
+ t: t,
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ test.setupFilter(t, s)
+
+ r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err)
+ }
+ defer r.Release()
+
+ pkts := test.genPacket(r)
+ pktsLen := pkts.Len()
+ if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{
+ Protocol: header.UDPProtocolNumber,
+ TTL: 64,
+ }); err != nil {
+ t.Fatalf("WritePackets(...): %s", err)
+ } else if n != pktsLen {
+ t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen)
+ }
+
+ if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent {
+ t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent)
+ }
+ if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped {
+ t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index f2301a9e6..824f81a42 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -487,30 +487,25 @@ func TestGetLinkAddress(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- for _, useNeighborCache := range []bool{true, false} {
- t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- UseNeighborCache: useNeighborCache,
- }
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ }
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
- ch := make(chan stack.LinkResolutionResult, 1)
- err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
- }
- wantRes := stack.LinkResolutionResult{Success: test.expectedOk}
- if test.expectedOk {
- wantRes.LinkAddress = linkAddr2
- }
- if diff := cmp.Diff(wantRes, <-ch); diff != "" {
- t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
+ }
+ wantRes := stack.LinkResolutionResult{Success: test.expectedOk}
+ if test.expectedOk {
+ wantRes.LinkAddress = linkAddr2
+ }
+ if diff := cmp.Diff(wantRes, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
}
})
}
@@ -587,66 +582,61 @@ func TestRouteResolvedFields(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- for _, useNeighborCache := range []bool{true, false} {
- t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- UseNeighborCache: useNeighborCache,
- }
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ }
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
- r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
- }
- defer r.Release()
-
- var wantRouteInfo stack.RouteInfo
- wantRouteInfo.LocalLinkAddress = linkAddr1
- wantRouteInfo.LocalAddress = test.localAddr
- wantRouteInfo.RemoteAddress = test.remoteAddr
- wantRouteInfo.NetProto = test.netProto
- wantRouteInfo.Loop = stack.PacketOut
- wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
-
- ch := make(chan stack.ResolvedFieldsResult, 1)
-
- if !test.immediatelyResolvable {
- wantUnresolvedRouteInfo := wantRouteInfo
- wantUnresolvedRouteInfo.RemoteLinkAddress = ""
-
- err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
- }
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
- }
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
- if !test.expectedSuccess {
- return
- }
+ var wantRouteInfo stack.RouteInfo
+ wantRouteInfo.LocalLinkAddress = linkAddr1
+ wantRouteInfo.LocalAddress = test.localAddr
+ wantRouteInfo.RemoteAddress = test.remoteAddr
+ wantRouteInfo.NetProto = test.netProto
+ wantRouteInfo.Loop = stack.PacketOut
+ wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
- // At this point the neighbor table should be populated so the route
- // should be immediately resolvable.
- }
+ ch := make(chan stack.ResolvedFieldsResult, 1)
- if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
- ch <- r
- }); err != nil {
- t.Errorf("r.ResolvedFields(_): %s", err)
- }
- select {
- case routeResolveRes := <-ch:
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected route to be immediately resolvable")
- }
+ if !test.immediatelyResolvable {
+ wantUnresolvedRouteInfo := wantRouteInfo
+ wantUnresolvedRouteInfo.RemoteLinkAddress = ""
+
+ err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
+ }
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
+ }
+
+ if !test.expectedSuccess {
+ return
+ }
+
+ // At this point the neighbor table should be populated so the route
+ // should be immediately resolvable.
+ }
+
+ if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
+ }); err != nil {
+ t.Errorf("r.ResolvedFields(_): %s", err)
+ }
+ select {
+ case routeResolveRes := <-ch:
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected route to be immediately resolvable")
}
})
}
@@ -1065,7 +1055,6 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
Clock: clock,
- UseNeighborCache: true,
}
host1StackOpts := stackOpts
host1StackOpts.NUDDisp = &nudDisp
@@ -1210,3 +1199,148 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
})
}
}
+
+func TestDAD(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ dadConfigs := stack.DADConfigurations{
+ DupAddrDetectTransmits: 1,
+ RetransmitTimer: time.Second,
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ dadNetProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedResolved bool
+ }{
+ {
+ name: "IPv4 own address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: ipv4Addr1.AddressWithPrefix.Address,
+ expectedResolved: true,
+ },
+ {
+ name: "IPv6 own address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr1.AddressWithPrefix.Address,
+ expectedResolved: true,
+ },
+ {
+ name: "IPv4 duplicate address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedResolved: false,
+ },
+ {
+ name: "IPv6 duplicate address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedResolved: false,
+ },
+ {
+ name: "IPv4 no duplicate address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ expectedResolved: true,
+ },
+ {
+ name: "IPv6 no duplicate address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ expectedResolved: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ stackOpts := stack.Options{
+ Clock: clock,
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ arp.NewProtocol,
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ }
+
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ // DAD should be disabled by default.
+ if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
+ t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled")
+ }); err != nil {
+ t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err)
+ } else if res != stack.DADDisabled {
+ t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled)
+ }
+
+ // Enable DAD then attempt to check if an address is duplicated.
+ netEP, err := host1Stack.GetNetworkEndpoint(host1NICID, test.dadNetProto)
+ if err != nil {
+ t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", host1NICID, test.dadNetProto, err)
+ }
+ dad, ok := netEP.(stack.DuplicateAddressDetector)
+ if !ok {
+ t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP)
+ }
+ dad.SetDADConfigurations(dadConfigs)
+ ch := make(chan stack.DADResult, 3)
+ if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err)
+ } else if res != stack.DADStarting {
+ t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting)
+ }
+
+ expectResults := 1
+ if test.expectedResolved {
+ const delta = time.Nanosecond
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta)
+ select {
+ case r := <-ch:
+ t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r)
+ default:
+ }
+
+ // If we expect the resolve to succeed try requesting DAD again on the
+ // same address. The handler for the new request should be called once
+ // the original DAD request completes.
+ expectResults = 2
+ if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err)
+ } else if res != stack.DADAlreadyRunning {
+ t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning)
+ }
+
+ clock.Advance(delta)
+ }
+
+ for i := 0; i < expectResults; i++ {
+ if diff := cmp.Diff(stack.DADResult{Resolved: test.expectedResolved}, <-ch); diff != "" {
+ t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ // Should have no more results.
+ select {
+ case r := <-ch:
+ t.Errorf("unexpectedly got an extra DAD result; r = %#v", r)
+ default:
+ }
+ })
+ }
+}