summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go10
-rw-r--r--pkg/tcpip/stack/nic.go10
-rw-r--r--pkg/tcpip/stack/packet_buffer.go11
-rw-r--r--pkg/tcpip/stack/pending_packets.go8
-rw-r--r--pkg/tcpip/stack/registration.go21
-rw-r--r--pkg/tcpip/stack/stack_test.go52
-rw-r--r--pkg/tcpip/stack/transport_test.go14
7 files changed, 98 insertions, 28 deletions
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 6883045b5..03b2f2d6f 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -83,7 +83,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
got, ch, err := c.get(addr, linkRes, "", nil, nil)
if err == tcpip.ErrWouldBlock {
if attemptedResolution {
- return got, tcpip.ErrNoLinkAddress
+ return got, tcpip.ErrTimeout
}
attemptedResolution = true
<-ch
@@ -253,8 +253,8 @@ func TestCacheResolutionFailed(t *testing.T) {
before := atomic.LoadUint32(&requestCount)
e.addr.Addr += "2"
- if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
+ t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
}
if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
@@ -269,8 +269,8 @@ func TestCacheResolutionTimeout(t *testing.T) {
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
e := testAddrs[0]
- if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
+ t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 4a34805b5..8a946b4fa 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -217,6 +217,16 @@ func (n *NIC) disableLocked() {
ep.Disable()
}
+ // Clear the neighbour table (including static entries) as we cannot guarantee
+ // that the current neighbour table will be valid when the NIC is enabled
+ // again.
+ //
+ // This matches linux's behaviour at the time of writing:
+ // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371
+ if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported {
+ panic(fmt.Sprintf("n.clearNeighbors(): %s", err))
+ }
+
if !n.setEnabled(false) {
panic("should have only done work to disable the NIC if it was enabled")
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 664cc6fa0..5f216ca21 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -268,17 +268,6 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
}
}
-// SourceLinkAddress returns the source link address of the packet.
-func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress {
- link := pk.LinkHeader().View()
-
- if link.IsEmpty() {
- return ""
- }
-
- return header.Ethernet(link).SourceAddress()
-}
-
// Network returns the network header as a header.Network.
//
// Network should only be called when NetworkHeader has been set.
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 4a3adcf33..bded8814e 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -101,10 +101,12 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
}
for _, p := range packets {
- if cancelled {
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
- } else if p.route.IsResolutionRequired() {
+ if cancelled || p.route.IsResolutionRequired() {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
+
+ if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok {
+ linkResolvableEP.HandleLinkResolutionFailure(pkt)
+ }
} else {
p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 4795208b4..924790779 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -55,7 +55,19 @@ type ControlType int
// The following are the allowed values for ControlType values.
// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlNetworkUnreachable ControlType = iota
+ // ControlAddressUnreachable indicates that an IPv6 packet did not reach its
+ // destination as the destination address was unreachable.
+ //
+ // This maps to the ICMPv6 Destination Ureachable Code 3 error; see
+ // RFC 4443 section 3.1 for more details.
+ ControlAddressUnreachable ControlType = iota
+ ControlNetworkUnreachable
+ // ControlNoRoute indicates that an IPv4 packet did not reach its destination
+ // because the destination host was unreachable.
+ //
+ // This maps to the ICMPv4 Destination Ureachable Code 1 error; see
+ // RFC 791's Destination Unreachable Message section (page 4) for more
+ // details.
ControlNoRoute
ControlPacketTooBig
ControlPortUnreachable
@@ -503,6 +515,13 @@ type NetworkInterface interface {
WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
+// LinkResolvableNetworkEndpoint handles link resolution events.
+type LinkResolvableNetworkEndpoint interface {
+ // HandleLinkResolutionFailure is called when link resolution prevents the
+ // argument from having been sent.
+ HandleLinkResolutionFailure(*PacketBuffer)
+}
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 856ebf6d4..4a3f937e3 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -4305,3 +4305,55 @@ func TestWritePacketToRemote(t *testing.T) {
}
})
}
+
+func TestClearNeighborCacheOnNICDisable(t *testing.T) {
+ const (
+ nicID = 1
+
+ ipv4Addr = tcpip.Address("\x01\x02\x03\x04")
+ ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04")
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 0, "")
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err)
+ }
+ if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err)
+ }
+ if neighbors, err := s.Neighbors(nicID); err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ } else if len(neighbors) != 2 {
+ t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors)
+ }
+
+ // Disabling the NIC should clear the neighbor table.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ if neighbors, err := s.Neighbors(nicID); err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ }
+
+ // Enabling the NIC should have an empty neighbor table.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if neighbors, err := s.Neighbors(nicID); err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 0ff32c6ea..a2ab7537c 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -90,14 +90,14 @@ func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.Rea
return tcpip.ReadResult{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
- return 0, nil, tcpip.ErrNoRoute
+ return 0, tcpip.ErrNoRoute
}
v, err := p.FullPayload()
if err != nil {
- return 0, nil, err
+ return 0, err
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
@@ -105,10 +105,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
})
_ = pkt.TransportHeader().Push(fakeTransHeaderLen)
if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
- return 0, nil, err
+ return 0, err
}
- return int64(len(v)), nil, nil
+ return int64(len(v)), nil
}
// SetSockOpt sets a socket option. Currently not supported.
@@ -222,7 +222,6 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
if err != nil {
return
}
- route.ResolveWith(pkt.SourceLinkAddress())
ep := &fakeTransportEndpoint{
TransportEndpointInfo: stack.TransportEndpointInfo{
@@ -522,8 +521,7 @@ func TestTransportSend(t *testing.T) {
// Create buffer that will hold the payload.
view := buffer.NewView(30)
- _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
- if err != nil {
+ if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("write failed: %v", err)
}