diff options
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 301 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 434 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 11 |
9 files changed, 633 insertions, 162 deletions
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 204196d00..eea32dcf5 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -297,10 +297,9 @@ func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Li // no matching entry for the remote address. } -// HandleUpperLevelConfirmation implements -// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in -// RFC 4861 section 7.3.1. -func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { +// handleUpperLevelConfirmation processes a confirmation of reachablity from +// some protocol that operates at a layer above the IP/link layer. +func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 1bbfe6213..f59416fd3 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -561,6 +561,12 @@ func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress } +func (n *NIC) confirmReachable(addr tcpip.Address) { + if n := n.neigh; n != nil { + n.handleUpperLevelConfirmation(addr) + } +} + func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if n.neigh != nil { entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 12d67409a..77926e289 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -174,10 +174,6 @@ type NUDHandler interface { // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) - - // HandleUpperLevelConfirmation processes an incoming upper-level protocol - // (e.g. TCP acknowledgements) reachability confirmation. - HandleUpperLevelConfirmation(addr tcpip.Address) } // NUDConfigurations is the NUD configurations for the netstack. This is used diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index d9a8554e2..9c8c155fa 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -354,11 +354,6 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn return fields, nil, nil } - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress - } - // If specified, the local address used for link address resolution must be an // address on the outgoing interface. var linkAddressResolutionRequestLocalAddr tcpip.Address @@ -367,7 +362,7 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn } afterResolveFields := fields - linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { + linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { if afterResolve != nil { if r.Success { afterResolveFields.RemoteLinkAddress = r.LinkAddress @@ -382,6 +377,13 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn return fields, ch, err } +func (r *Route) nextHop() tcpip.Address { + if len(r.NextHop) == 0 { + return r.RemoteAddress + } + return r.NextHop +} + // local returns true if the route is a local route. func (r *Route) local() bool { return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() @@ -519,3 +521,12 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } + +// ConfirmReachable informs the network/link layer that the neighbour used for +// the route is reachable. +// +// "Reachable" is defined as having full-duplex communication between the +// local and remote ends of the route. +func (r *Route) ConfirmReachable() { + r.outgoingNIC.confirmReachable(r.nextHop()) +} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 218b218e7..71695b630 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -17,6 +17,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index aedf1845e..2b5832c62 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -38,96 +38,207 @@ import ( var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil) var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil) -// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner -// link endpoint and checks the destination link address before delivering -// network packets to the network dispatcher. -// -// See ethernet.Endpoint for more details. -func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { - var e endpointWithDestinationCheck - e.Endpoint.Init(ethernet.New(ep), &e) - return &e -} - -// endpointWithDestinationCheck is a link endpoint that checks the destination -// link address before delivering network packets to the network dispatcher. -type endpointWithDestinationCheck struct { - nested.Endpoint -} - -// DeliverNetworkPacket implements stack.NetworkDispatcher. -func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) - } -} - -func TestForwarding(t *testing.T) { - const ( - host1NICID = 1 - routerNICID1 = 2 - routerNICID2 = 3 - host2NICID = 4 - - listenPort = 8080 - ) +const ( + host1NICID = 1 + routerNICID1 = 2 + routerNICID2 = 3 + host2NICID = 4 +) - host1IPv4Addr := tcpip.ProtocolAddress{ +var ( + host1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 24, }, } - routerNIC1IPv4Addr := tcpip.ProtocolAddress{ + routerNIC1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - routerNIC2IPv4Addr := tcpip.ProtocolAddress{ + routerNIC2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, }, } - host2IPv4Addr := tcpip.ProtocolAddress{ + host2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr := tcpip.ProtocolAddress{ + host1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), PrefixLen: 64, }, } - routerNIC1IPv6Addr := tcpip.ProtocolAddress{ + routerNIC1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - routerNIC2IPv6Addr := tcpip.ProtocolAddress{ + routerNIC2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr := tcpip.ProtocolAddress{ + host2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::2").To16()), PrefixLen: 64, }, } +) + +func setupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) { + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + { + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + { + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) +} + +// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner +// link endpoint and checks the destination link address before delivering +// network packets to the network dispatcher. +// +// See ethernet.Endpoint for more details. +func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { + var e endpointWithDestinationCheck + e.Endpoint.Init(ethernet.New(ep), &e) + return &e +} + +// endpointWithDestinationCheck is a link endpoint that checks the destination +// link address before delivering network packets to the network dispatcher. +type endpointWithDestinationCheck struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) + } +} + +func TestForwarding(t *testing.T) { + const listenPort = 8080 type endpointAndAddresses struct { serverEP tcpip.Endpoint @@ -294,113 +405,7 @@ func TestForwarding(t *testing.T) { host1Stack := stack.New(stackOpts) routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - - host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) - routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) - } - if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - { - Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - }) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) defer epsAndAddrs.serverEP.Close() diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index f85164c5b..3c25d1e0c 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -19,12 +19,14 @@ import ( "fmt" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -728,3 +730,435 @@ func TestWritePacketsLinkResolution(t *testing.T) { }) } } + +type eventType int + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved +) + +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + entry stack.NeighborEntry +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) +} + +var _ stack.NUDDispatcher = (*nudDispatcher)(nil) + +type nudDispatcher struct { + c chan eventInfo +} + +func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryChanged, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) waitForEvent(want eventInfo) error { + if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil +} + +// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it +// that the neighbor used for a route is reachable. +func TestTCPConfirmNeighborReachability(t *testing.T) { + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + neighborAddr tcpip.Address + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + isHost1Listener bool + }{ + { + name: "IPv4 active connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 active connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 passive connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv4 passive connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + nudDisp := nudDispatcher{ + c: make(chan eventInfo, 3), + } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, + UseNeighborCache: true, + } + host1StackOpts := stackOpts + host1StackOpts.NUDDisp = &nudDisp + + host1Stack := stack.New(host1StackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + // Add a reachable dynamic entry to our neighbor table for the remote. + { + ch := make(chan stack.LinkResolutionResult, 1) + if err := host1Stack.GetLinkAddress(host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }); err != tcpip.ErrWouldBlock { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.neighborAddr, test.netProto, err, tcpip.ErrWouldBlock) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: linkAddr2, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryAdded, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, + }); err != nil { + t.Fatalf("error waiting for initial NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the remote's neighbor entry to be stale before creating a + // TCP connection from host1 to some remote. + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err) + } + // The maximum reachable time for a neighbor is some maximum random factor + // applied to the base reachable time. + // + // See NUDConfigurations.BaseReachableTime for more information. + maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + + listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) + defer listenerEP.Close() + defer clientEP.Close() + listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} + if err := listenerEP.Bind(listenerAddr); err != nil { + t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) + } + if err := listenerEP.Listen(1); err != nil { + t.Fatalf("listenerEP.Listen(1): %s", err) + } + if err := clientEP.Connect(listenerAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, tcpip.ErrConnectStarted) + } + + // Wait for the TCP handshake to complete then make sure the neighbor is + // reachable without entering the probe state as TCP should provide NUD + // with confirmation that the neighbor is reachable (indicated by a + // successful 3-way handshake). + <-clientCH + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the neighbor to be stale again then send data to the remote. + // + // On successful transmission, the neighbor should become reachable + // without probing the neighbor as a TCP ACK would be received which is an + // indication of the neighbor being reachable. + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if test.isHost1Listener { + // If host1 is not the client, host1 does not send any data so TCP + // has no way to know it is making forward progress. Because of this, + // TCP should not mark the route reachable and NUD should go through the + // probe state. + clock.Advance(nudConfigs.DelayFirstProbeTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for probe NUD event: %s", err) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 62954d7e4..6df4e6525 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1335,6 +1335,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } + // Reaching this point means that we successfully completed the 3-way + // handshake with our peer. + // + // Completing the 3-way handshake is an indication that the route is valid + // and the remote is reachable as the only way we can complete a handshake + // is if our SYN reached the remote and their ACK reached us. + e.route.ConfirmReachable() + drained := e.drainDone != nil if drained { close(e.drainDone) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 28ef9f899..027c2a4a8 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1390,6 +1390,17 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { acked := s.sndUna.Size(ack) s.sndUna = ack + // The remote ACK-ing at least 1 byte is an indication that we have a + // full-duplex connection to the remote as the only way we will receive an + // ACK is if the remote received data that we previously sent. + // + // As of writing, linux seems to only confirm a route as reachable when + // forward progress is made which is indicated by an ACK that removes data + // from the retransmit queue. + if acked > 0 { + s.ep.route.ConfirmReachable() + } + ackLeft := acked originalOutstanding := s.outstanding for ackLeft > 0 { |