diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/pending_packets.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 14 |
7 files changed, 98 insertions, 28 deletions
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 6883045b5..03b2f2d6f 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -83,7 +83,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe got, ch, err := c.get(addr, linkRes, "", nil, nil) if err == tcpip.ErrWouldBlock { if attemptedResolution { - return got, tcpip.ErrNoLinkAddress + return got, tcpip.ErrTimeout } attemptedResolution = true <-ch @@ -253,8 +253,8 @@ func TestCacheResolutionFailed(t *testing.T) { before := atomic.LoadUint32(&requestCount) e.addr.Addr += "2" - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { + t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -269,8 +269,8 @@ func TestCacheResolutionTimeout(t *testing.T) { linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} e := testAddrs[0] - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { + t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4a34805b5..8a946b4fa 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -217,6 +217,16 @@ func (n *NIC) disableLocked() { ep.Disable() } + // Clear the neighbour table (including static entries) as we cannot guarantee + // that the current neighbour table will be valid when the NIC is enabled + // again. + // + // This matches linux's behaviour at the time of writing: + // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 + if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported { + panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) + } + if !n.setEnabled(false) { panic("should have only done work to disable the NIC if it was enabled") } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 664cc6fa0..5f216ca21 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -268,17 +268,6 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { } } -// SourceLinkAddress returns the source link address of the packet. -func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { - link := pk.LinkHeader().View() - - if link.IsEmpty() { - return "" - } - - return header.Ethernet(link).SourceAddress() -} - // Network returns the network header as a header.Network. // // Network should only be called when NetworkHeader has been set. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 4a3adcf33..bded8814e 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -101,10 +101,12 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } for _, p := range packets { - if cancelled { - p.route.Stats().IP.OutgoingPacketErrors.Increment() - } else if p.route.IsResolutionRequired() { + if cancelled || p.route.IsResolutionRequired() { p.route.Stats().IP.OutgoingPacketErrors.Increment() + + if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { + linkResolvableEP.HandleLinkResolutionFailure(pkt) + } } else { p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 4795208b4..924790779 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -55,7 +55,19 @@ type ControlType int // The following are the allowed values for ControlType values. // TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlNetworkUnreachable ControlType = iota + // ControlAddressUnreachable indicates that an IPv6 packet did not reach its + // destination as the destination address was unreachable. + // + // This maps to the ICMPv6 Destination Ureachable Code 3 error; see + // RFC 4443 section 3.1 for more details. + ControlAddressUnreachable ControlType = iota + ControlNetworkUnreachable + // ControlNoRoute indicates that an IPv4 packet did not reach its destination + // because the destination host was unreachable. + // + // This maps to the ICMPv4 Destination Ureachable Code 1 error; see + // RFC 791's Destination Unreachable Message section (page 4) for more + // details. ControlNoRoute ControlPacketTooBig ControlPortUnreachable @@ -503,6 +515,13 @@ type NetworkInterface interface { WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } +// LinkResolvableNetworkEndpoint handles link resolution events. +type LinkResolvableNetworkEndpoint interface { + // HandleLinkResolutionFailure is called when link resolution prevents the + // argument from having been sent. + HandleLinkResolutionFailure(*PacketBuffer) +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 856ebf6d4..4a3f937e3 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4305,3 +4305,55 @@ func TestWritePacketToRemote(t *testing.T) { } }) } + +func TestClearNeighborCacheOnNICDisable(t *testing.T) { + const ( + nicID = 1 + + ipv4Addr = tcpip.Address("\x01\x02\x03\x04") + ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04") + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, 0, "") + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err) + } + if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err) + } + if neighbors, err := s.Neighbors(nicID); err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } else if len(neighbors) != 2 { + t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors) + } + + // Disabling the NIC should clear the neighbor table. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if neighbors, err := s.Neighbors(nicID); err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } else if len(neighbors) != 0 { + t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + + // Enabling the NIC should have an empty neighbor table. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if neighbors, err := s.Neighbors(nicID); err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } else if len(neighbors) != 0 { + t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 0ff32c6ea..a2ab7537c 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -90,14 +90,14 @@ func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.Rea return tcpip.ReadResult{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } v, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, @@ -105,10 +105,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions }) _ = pkt.TransportHeader().Push(fakeTransHeaderLen) if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { - return 0, nil, err + return 0, err } - return int64(len(v)), nil, nil + return int64(len(v)), nil } // SetSockOpt sets a socket option. Currently not supported. @@ -222,7 +222,6 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * if err != nil { return } - route.ResolveWith(pkt.SourceLinkAddress()) ep := &fakeTransportEndpoint{ TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -522,8 +521,7 @@ func TestTransportSend(t *testing.T) { // Create buffer that will hold the payload. view := buffer.NewView(30) - _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { + if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("write failed: %v", err) } |