diff options
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 314 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 538 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/loopback_test.go | 31 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/route_test.go | 21 |
6 files changed, 710 insertions, 206 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 218b218e7..71695b630 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -17,6 +17,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index aedf1845e..38e1881c7 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -38,96 +38,207 @@ import ( var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil) var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil) -// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner -// link endpoint and checks the destination link address before delivering -// network packets to the network dispatcher. -// -// See ethernet.Endpoint for more details. -func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { - var e endpointWithDestinationCheck - e.Endpoint.Init(ethernet.New(ep), &e) - return &e -} - -// endpointWithDestinationCheck is a link endpoint that checks the destination -// link address before delivering network packets to the network dispatcher. -type endpointWithDestinationCheck struct { - nested.Endpoint -} - -// DeliverNetworkPacket implements stack.NetworkDispatcher. -func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) - } -} - -func TestForwarding(t *testing.T) { - const ( - host1NICID = 1 - routerNICID1 = 2 - routerNICID2 = 3 - host2NICID = 4 - - listenPort = 8080 - ) +const ( + host1NICID = 1 + routerNICID1 = 2 + routerNICID2 = 3 + host2NICID = 4 +) - host1IPv4Addr := tcpip.ProtocolAddress{ +var ( + host1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 24, }, } - routerNIC1IPv4Addr := tcpip.ProtocolAddress{ + routerNIC1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - routerNIC2IPv4Addr := tcpip.ProtocolAddress{ + routerNIC2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, }, } - host2IPv4Addr := tcpip.ProtocolAddress{ + host2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr := tcpip.ProtocolAddress{ + host1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), PrefixLen: 64, }, } - routerNIC1IPv6Addr := tcpip.ProtocolAddress{ + routerNIC1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - routerNIC2IPv6Addr := tcpip.ProtocolAddress{ + routerNIC2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr := tcpip.ProtocolAddress{ + host2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::2").To16()), PrefixLen: 64, }, } +) + +func setupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) { + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + { + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + { + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) +} + +// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner +// link endpoint and checks the destination link address before delivering +// network packets to the network dispatcher. +// +// See ethernet.Endpoint for more details. +func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { + var e endpointWithDestinationCheck + e.Endpoint.Init(ethernet.New(ep), &e) + return &e +} + +// endpointWithDestinationCheck is a link endpoint that checks the destination +// link address before delivering network packets to the network dispatcher. +type endpointWithDestinationCheck struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) + } +} + +func TestForwarding(t *testing.T) { + const listenPort = 8080 type endpointAndAddresses struct { serverEP tcpip.Endpoint @@ -229,7 +340,7 @@ func TestForwarding(t *testing.T) { subTests := []struct { name string proto tcpip.TransportProtocolNumber - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) needRemoteAddr bool }{ @@ -250,7 +361,7 @@ func TestForwarding(t *testing.T) { { name: "TCP", proto: tcp.ProtocolNumber, - expectedConnectErr: tcpip.ErrConnectStarted, + expectedConnectErr: &tcpip.ErrConnectStarted{}, setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() @@ -260,7 +371,7 @@ func TestForwarding(t *testing.T) { var addr tcpip.FullAddress for { newEP, wq, err := ep.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-ch continue } @@ -294,113 +405,7 @@ func TestForwarding(t *testing.T) { host1Stack := stack.New(stackOpts) routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - - host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) - routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) - } - if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - { - Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - }) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) defer epsAndAddrs.serverEP.Close() @@ -415,8 +420,11 @@ func TestForwarding(t *testing.T) { t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) } - if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr { - t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr) + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } } if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index f85164c5b..f2301a9e6 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -19,12 +19,14 @@ import ( "fmt" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -245,6 +247,14 @@ func TestPing(t *testing.T) { } } +type transportError struct { + origin tcpip.SockErrOrigin + typ uint8 + code uint8 + info uint32 + kind stack.TransportErrorKind +} + func TestTCPLinkResolutionFailure(t *testing.T) { const ( host1NICID = 1 @@ -255,8 +265,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name string netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error sockError tcpip.SockError + transErr transportError }{ { name: "IPv4 with resolvable remote", @@ -274,12 +285,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name: "IPv4 without resolvable remote", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, - ErrType: byte(header.ICMPv4DstUnreachable), - ErrCode: byte(header.ICMPv4HostUnreachable), - ErrOrigin: tcpip.SockExtErrorOriginICMP, + Err: &tcpip.ErrNoRoute{}, Dst: tcpip.FullAddress{ NIC: host1NICID, Addr: ipv4Addr3.AddressWithPrefix.Address, @@ -291,17 +299,20 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }, NetProto: ipv4.ProtocolNumber, }, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4HostUnreachable), + kind: stack.DestinationHostUnreachableTransportError, + }, }, { name: "IPv6 without resolvable remote", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, - ErrType: byte(header.ICMPv6DstUnreachable), - ErrCode: byte(header.ICMPv6AddressUnreachable), - ErrOrigin: tcpip.SockExtErrorOriginICMP6, + Err: &tcpip.ErrNoRoute{}, Dst: tcpip.FullAddress{ NIC: host1NICID, Addr: ipv6Addr3.AddressWithPrefix.Address, @@ -313,6 +324,12 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }, NetProto: ipv6.ProtocolNumber, }, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6DstUnreachable), + code: uint8(header.ICMPv6AddressUnreachable), + kind: stack.DestinationHostUnreachableTransportError, + }, }, } @@ -355,18 +372,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr := listenerAddr remoteAddr.Addr = test.remoteAddr - if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted) + { + err := clientEP.Connect(remoteAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{}) + } } // Wait for an error due to link resolution failing, or the endpoint to be // writable. <-ch - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if n, err := clientEP.Write(&r, wOpts); err != test.expectedWriteErr { - t.Errorf("got clientEP.Write(_, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + _, err := clientEP.Write(&r, wOpts) + if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" { + t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff) + } } if test.expectedWriteErr == nil { @@ -380,14 +403,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) { sockErrCmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(tcpip.SockError{}), - cmp.Comparer(func(a, b *tcpip.Error) bool { + cmp.Comparer(func(a, b tcpip.Error) bool { // tcpip.Error holds an unexported field but the errors netstack uses // are pre defined so we can simply compare pointers. return a == b }), - // Ignore the payload since we do not know the TCP seq/ack numbers. checker.IgnoreCmpPath( + // Ignore the payload since we do not know the TCP seq/ack numbers. "Payload", + // Ignore the cause since we will compare its properties separately + // since the concrete type of the cause is unknown. + "Cause", ), } @@ -399,6 +425,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { t.Errorf("socket error mismatch (-want +got):\n%s", diff) } + + transErr, ok := sockErr.Cause.(stack.TransportError) + if !ok { + t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause) + } + if diff := cmp.Diff( + test.transErr, + transportError{ + origin: transErr.Origin(), + typ: transErr.Type(), + code: transErr.Code(), + info: transErr.Info(), + kind: transErr.Kind(), + }, + cmp.AllowUnexported(transportError{}), + ); diff != "" { + t.Errorf("socket error mismatch (-want +got):\n%s", diff) + } }) } } @@ -453,10 +497,11 @@ func TestGetLinkAddress(t *testing.T) { host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) ch := make(chan stack.LinkResolutionResult, 1) - if err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { ch <- r - }); err != tcpip.ErrWouldBlock { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, tcpip.ErrWouldBlock) + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) } wantRes := stack.LinkResolutionResult{Success: test.expectedOk} if test.expectedOk { @@ -570,10 +615,11 @@ func TestRouteResolvedFields(t *testing.T) { wantUnresolvedRouteInfo := wantRouteInfo wantUnresolvedRouteInfo.RemoteLinkAddress = "" - if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { ch <- r - }); err != tcpip.ErrWouldBlock { - t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, tcpip.ErrWouldBlock) + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) @@ -616,7 +662,7 @@ func TestWritePacketsLinkResolution(t *testing.T) { name string netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error }{ { name: "IPv4", @@ -703,7 +749,7 @@ func TestWritePacketsLinkResolution(t *testing.T) { var rOpts tcpip.ReadOptions res, err := serverEP.Read(&writer, rOpts) if err != nil { - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Should not have anymore bytes to read after we read the sent // number of bytes. if count == len(data) { @@ -728,3 +774,439 @@ func TestWritePacketsLinkResolution(t *testing.T) { }) } } + +type eventType int + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved +) + +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + entry stack.NeighborEntry +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) +} + +var _ stack.NUDDispatcher = (*nudDispatcher)(nil) + +type nudDispatcher struct { + c chan eventInfo +} + +func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryChanged, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) waitForEvent(want eventInfo) error { + if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil +} + +// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it +// that the neighbor used for a route is reachable. +func TestTCPConfirmNeighborReachability(t *testing.T) { + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + neighborAddr tcpip.Address + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + isHost1Listener bool + }{ + { + name: "IPv4 active connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 active connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 passive connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv4 passive connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + nudDisp := nudDispatcher{ + c: make(chan eventInfo, 3), + } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, + UseNeighborCache: true, + } + host1StackOpts := stackOpts + host1StackOpts.NUDDisp = &nudDisp + + host1Stack := stack.New(host1StackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + // Add a reachable dynamic entry to our neighbor table for the remote. + { + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: linkAddr2, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryAdded, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, + }); err != nil { + t.Fatalf("error waiting for initial NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the remote's neighbor entry to be stale before creating a + // TCP connection from host1 to some remote. + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + // The maximum reachable time for a neighbor is some maximum random factor + // applied to the base reachable time. + // + // See NUDConfigurations.BaseReachableTime for more information. + maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + + listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) + defer listenerEP.Close() + defer clientEP.Close() + listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} + if err := listenerEP.Bind(listenerAddr); err != nil { + t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) + } + if err := listenerEP.Listen(1); err != nil { + t.Fatalf("listenerEP.Listen(1): %s", err) + } + { + err := clientEP.Connect(listenerAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{}) + } + } + + // Wait for the TCP handshake to complete then make sure the neighbor is + // reachable without entering the probe state as TCP should provide NUD + // with confirmation that the neighbor is reachable (indicated by a + // successful 3-way handshake). + <-clientCH + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the neighbor to be stale again then send data to the remote. + // + // On successful transmission, the neighbor should become reachable + // without probing the neighbor as a TCP ACK would be received which is an + // indication of the neighbor being reachable. + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if test.isHost1Listener { + // If host1 is not the client, host1 does not send any data so TCP + // has no way to know it is making forward progress. Because of this, + // TCP should not mark the route reachable and NUD should go through the + // probe state. + clock.Advance(nudConfigs.DelayFirstProbeTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for probe NUD event: %s", err) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 761283b66..ab67762ef 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -37,7 +37,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) type ndpDispatcher struct{} -func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { } func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { @@ -262,8 +262,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -322,11 +322,14 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) } - if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + { + err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) + } } } @@ -470,13 +473,17 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { Addr: test.dstAddr, Port: localPort, } - if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + { + err := connectingEndpoint.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + } } if !test.expectAccept { - if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + _, _, err := listeningEndpoint.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } return } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 9cc12fa58..d685fdd36 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -479,8 +479,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -761,8 +761,11 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { if err := ep.SetSockOpt(&removeOpt); err != nil { t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) } - if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) + { + _, err := ep.Read(&buf, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{}) + } } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 35ee7437a..9654c9527 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -81,7 +81,7 @@ func TestLocalPing(t *testing.T) { linkEndpoint func() stack.LinkEndpoint localAddr tcpip.Address icmpBuf func(*testing.T) buffer.View - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) }{ { @@ -126,7 +126,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -135,7 +135,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -144,7 +144,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, { @@ -153,7 +153,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, } @@ -186,8 +186,11 @@ func TestLocalPing(t *testing.T) { defer ep.Close() connAddr := tcpip.FullAddress{Addr: test.localAddr} - if err := ep.Connect(connAddr); err != test.expectedConnectErr { - t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + { + err := ep.Connect(connAddr) + if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) + } } if test.expectedConnectErr != nil { @@ -263,12 +266,12 @@ func TestLocalUDP(t *testing.T) { subTests := []struct { name string addAddress bool - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error }{ { name: "Unassigned local address", addAddress: false, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, }, { name: "Assigned local address", |