diff options
-rw-r--r-- | pkg/tcpip/link/pipe/pipe.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 61 |
7 files changed, 140 insertions, 49 deletions
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 12e246e21..d6e83a414 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -55,7 +55,22 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this // endpoint is the local address on the other end. - e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // + // Deliver the packet in a new goroutine to escape this goroutine's stack and + // avoid a deadlock when a packet triggers a response which leads the stack to + // try and take a lock it already holds. + // + // As of writing, a deadlock may occur when performing link resolution as the + // neighbor table will send a solicitation while holding a lock and the + // response advertisement will be sent in the same stack that sent the + // solictation. When the response is received, the stack attempts to take the + // same lock it already took before sending the solicitation, leading to a + // deadlock. Basically, we attempt to lock the same lock twice in the same + // call stack. + // + // TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send + // and receive queues. + go e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), })) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index c8d0323cb..9e106a161 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -547,6 +547,15 @@ func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress } +func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { + if n.neigh != nil { + entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) + return entry.LinkAddr, ch, err + } + + return n.stack.linkAddrCache.get(tcpip.FullAddress{NIC: n.ID(), Addr: addr}, linkRes, localAddr, n, onResolve) +} + func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { if n.neigh == nil { return nil, tcpip.ErrNotSupported diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0f6ec92c9..68c113b6a 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -835,27 +835,6 @@ type LinkAddressCache interface { // AddLinkAddress adds a link address to the cache. AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) - - // GetLinkAddress finds the link address corresponding to the remote address - // (e.g. IP -> MAC). - // - // Returns a link address for the remote address, if readily available. - // - // Returns ErrWouldBlock if the link address is not readily available, along - // with a notification channel for the caller to block on. Triggers address - // resolution asynchronously. - // - // If onResolve is provided, it will be called either immediately, if - // resolution is not required, or when address resolution is complete, with - // the resolved link address and whether resolution succeeded. After any - // callbacks have been called, the returned notification channel is closed. - // - // If specified, the local address must be an address local to the interface - // the neighbor cache belongs to. The local address is the source address of - // a packet prompting NUD/link address resolution. - // - // TODO(gvisor.dev/issue/5151): Don't return the link address. - GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index ac163904c..8dfde488b 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -51,10 +51,6 @@ type Route struct { // outgoingNIC is the interface this route uses to write packets. outgoingNIC *NIC - // linkCache is set if link address resolution is enabled for this protocol on - // the route's NIC. - linkCache LinkAddressCache - // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. linkRes LinkAddressResolver @@ -191,7 +187,6 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.outgoingNIC.stack } } @@ -338,19 +333,8 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { r.Release() } - if neigh := r.outgoingNIC.neigh; neigh != nil { - _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) - if err != nil { - return ch, err - } - return nil, nil - } - - _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution) - if err != nil { - return ch, err - } - return nil, nil + _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) + return ch, err } // local returns true if the route is a local route. @@ -373,7 +357,7 @@ func (r *Route) isResolutionRequiredRLocked() bool { return false } - return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil + return r.linkRes != nil } func (r *Route) isValidForOutgoing() bool { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 114643b03..281bb7a9d 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1518,19 +1518,41 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t // that AddLinkAddress for a particular address has been called. } -// GetLinkAddress implements LinkAddressCache.GetLinkAddress. +// GetLinkAddress finds the link address corresponding to the remote address. +// +// Returns a link address for the remote address, if readily available. +// +// Returns ErrNotSupported if the stack is not configured with a link address +// resolver for the specified network protocol. +// +// Returns ErrWouldBlock if the link address is not readily available, along +// with a notification channel for the caller to block on. Triggers address +// resolution asynchronously. +// +// If onResolve is provided, it will be called either immediately, if +// resolution is not required, or when address resolution is complete, with +// the resolved link address and whether resolution succeeded. After any +// callbacks have been called, the returned notification channel is closed. +// +// If specified, the local address must be an address local to the interface +// the neighbor cache belongs to. The local address is the source address of +// a packet prompting NUD/link address resolution. +// +// TODO(gvisor.dev/issue/5151): Don't return the link address. func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() - nic := s.nics[nicID] - if nic == nil { - s.mu.RUnlock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + if !ok { return "", nil, tcpip.ErrUnknownNICID } - s.mu.RUnlock() - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve) + linkRes, ok := s.linkAddrResolvers[protocol] + if !ok { + return "", nil, tcpip.ErrNotSupported + } + + return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) } // Neighbors returns all IP to MAC address associations. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 4a3f937e3..82ee066e6 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4357,3 +4357,24 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) } } + +func TestGetLinkAddressErrors(t *testing.T) { + const ( + nicID = 1 + unknownNICID = nicID + 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID) + } + if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 3f06c2145..af32d3009 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -16,6 +16,7 @@ package integration_test import ( "bytes" + "fmt" "net" "testing" @@ -395,3 +396,63 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }) } } + +func TestGetLinkAddress(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedLinkAddr bool + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, useNeighborCache := range []bool{true, false} { + t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + UseNeighborCache: useNeighborCache, + } + + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + + for i := 0; i < 2; i++ { + addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {}) + var want *tcpip.Error + if i == 0 { + want = tcpip.ErrWouldBlock + } + if err != want { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want) + } + + if i == 0 { + <-ch + continue + } + + if addr != linkAddr2 { + t.Fatalf("got addr = %s, want = %s", addr, linkAddr2) + } + } + }) + } + }) + } +} |