summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/devices/tundev/BUILD1
-rw-r--r--pkg/sentry/devices/tundev/tundev.go14
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/net_tun.go52
-rw-r--r--pkg/tcpip/header/eth.go16
-rw-r--r--pkg/tcpip/header/eth_test.go47
-rw-r--r--pkg/tcpip/link/pipe/BUILD15
-rw-r--r--pkg/tcpip/link/pipe/pipe.go124
-rw-r--r--pkg/tcpip/link/tun/device.go42
-rw-r--r--pkg/tcpip/network/arp/arp.go25
-rw-r--r--pkg/tcpip/network/ip_test.go118
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go40
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go26
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go203
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go51
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go221
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go36
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go2
-rw-r--r--pkg/tcpip/stack/BUILD5
-rw-r--r--pkg/tcpip/stack/forwarding_test.go (renamed from pkg/tcpip/stack/forwarder_test.go)12
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go4
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go5
-rw-r--r--pkg/tcpip/stack/nic.go125
-rw-r--r--pkg/tcpip/stack/nic_test.go10
-rw-r--r--pkg/tcpip/stack/pending_packets.go (renamed from pkg/tcpip/stack/forwarder.go)60
-rw-r--r--pkg/tcpip/stack/registration.go33
-rw-r--r--pkg/tcpip/stack/route.go48
-rw-r--r--pkg/tcpip/stack/stack.go42
-rw-r--r--pkg/tcpip/stack/stack_test.go62
-rw-r--r--pkg/tcpip/tests/integration/BUILD4
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go378
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go219
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go8
34 files changed, 1636 insertions, 414 deletions
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
index 71c59287c..14a8bf9cd 100644
--- a/pkg/sentry/devices/tundev/BUILD
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -17,6 +17,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index 0b701a289..655ea549b 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -16,6 +16,8 @@
package tundev
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -26,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -84,7 +87,16 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg
return 0, err
}
flags := usermem.ByteOrder.Uint16(req.Data[:])
- return 0, fd.device.SetIff(stack.Stack, req.Name(), flags)
+ created, err := fd.device.SetIff(stack.Stack, req.Name(), flags)
+ if err == nil && created {
+ // Always start with an ARP address for interfaces so they can handle ARP
+ // packets.
+ nicID := fd.device.NICID()
+ if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID))
+ }
+ }
+ return 0, err
case linux.TUNGETIFF:
var req linux.IFReq
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 9379a4d7b..6b7b451b8 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/socket/netstack",
"//pkg/syserror",
"//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 5f8c9b5a2..19ffdec47 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -15,6 +15,8 @@
package dev
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -25,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -60,7 +63,7 @@ func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMod
}
// GetFile implements fs.InodeOperations.GetFile.
-func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+func (*netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
}
@@ -80,12 +83,12 @@ type netTunFileOperations struct {
var _ fs.FileOperations = (*netTunFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (fops *netTunFileOperations) Release(ctx context.Context) {
- fops.device.Release(ctx)
+func (n *netTunFileOperations) Release(ctx context.Context) {
+ n.device.Release(ctx)
}
// Ioctl implements fs.FileOperations.Ioctl.
-func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
request := args[1].Uint()
data := args[2].Pointer()
@@ -109,16 +112,25 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u
return 0, err
}
flags := usermem.ByteOrder.Uint16(req.Data[:])
- return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+ created, err := n.device.SetIff(stack.Stack, req.Name(), flags)
+ if err == nil && created {
+ // Always start with an ARP address for interfaces so they can handle ARP
+ // packets.
+ nicID := n.device.NICID()
+ if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID))
+ }
+ }
+ return 0, err
case linux.TUNGETIFF:
var req linux.IFReq
- copy(req.IFName[:], fops.device.Name())
+ copy(req.IFName[:], n.device.Name())
// Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
// there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
- flags := fops.device.Flags() | linux.IFF_NOFILTER
+ flags := n.device.Flags() | linux.IFF_NOFILTER
usermem.ByteOrder.PutUint16(req.Data[:], flags)
_, err := req.CopyOut(t, data)
@@ -130,41 +142,41 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u
}
// Write implements fs.FileOperations.Write.
-func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+func (n *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
data := make([]byte, src.NumBytes())
if _, err := src.CopyIn(ctx, data); err != nil {
return 0, err
}
- return fops.device.Write(data)
+ return n.device.Write(data)
}
// Read implements fs.FileOperations.Read.
-func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- data, err := fops.device.Read()
+func (n *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := n.device.Read()
if err != nil {
return 0, err
}
- n, err := dst.CopyOut(ctx, data)
- if n > 0 && n < len(data) {
+ bytesCopied, err := dst.CopyOut(ctx, data)
+ if bytesCopied > 0 && bytesCopied < len(data) {
// Not an error for partial copying. Packet truncated.
err = nil
}
- return int64(n), err
+ return int64(bytesCopied), err
}
// Readiness implements watier.Waitable.Readiness.
-func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
- return fops.device.Readiness(mask)
+func (n *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return n.device.Readiness(mask)
}
// EventRegister implements watier.Waitable.EventRegister.
-func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- fops.device.EventRegister(e, mask)
+func (n *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ n.device.EventRegister(e, mask)
}
// EventUnregister implements watier.Waitable.EventUnregister.
-func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
- fops.device.EventUnregister(e)
+func (n *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ n.device.EventUnregister(e)
}
// isNetTunSupported returns whether /dev/net/tun device is supported for s.
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index eaface8cb..95ade0e5c 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -117,25 +117,31 @@ func (b Ethernet) Encode(e *EthernetFields) {
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}
-// IsValidUnicastEthernetAddress returns true if addr is a valid unicast
+// IsMulticastEthernetAddress returns true if the address is a multicast
+// ethernet address.
+func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool {
+ if len(addr) != EthernetAddressSize {
+ return false
+ }
+
+ return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0
+}
+
+// IsValidUnicastEthernetAddress returns true if the address is a unicast
// ethernet address.
func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
- // Must be of the right length.
if len(addr) != EthernetAddressSize {
return false
}
- // Must not be unspecified.
if addr == unspecifiedEthernetAddress {
return false
}
- // Must not be a multicast.
if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
return false
}
- // addr is a valid unicast ethernet address.
return true
}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 14413f2ce..3bc8b2b21 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -67,6 +67,53 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
}
}
+func TestIsMulticastEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.LinkAddress
+ expected bool
+ }{
+ {
+ "Nil",
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty",
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "InvalidLength",
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Unspecified",
+ unspecifiedEthernetAddress,
+ false,
+ },
+ {
+ "Multicast",
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ true,
+ },
+ {
+ "Unicast",
+ tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := IsMulticastEthernetAddress(test.addr); got != test.expected {
+ t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected)
+ }
+ })
+ }
+}
+
func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD
new file mode 100644
index 000000000..9f31c1ffc
--- /dev/null
+++ b/pkg/tcpip/link/pipe/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pipe",
+ srcs = ["pipe.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
new file mode 100644
index 000000000..76f563811
--- /dev/null
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe provides the implementation of pipe-like data-link layer
+// endpoints. Such endpoints allow packets to be sent between two interfaces.
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns both ends of a new pipe.
+func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) {
+ ep1 := &Endpoint{
+ linkAddr: linkAddr1,
+ capabilities: capabilities,
+ }
+ ep2 := &Endpoint{
+ linkAddr: linkAddr2,
+ linked: ep1,
+ capabilities: capabilities,
+ }
+ ep1.linked = ep2
+ return ep1, ep2
+}
+
+// Endpoint is one end of a pipe.
+type Endpoint struct {
+ capabilities stack.LinkEndpointCapabilities
+ linkAddr tcpip.LinkAddress
+ dispatcher stack.NetworkDispatcher
+ linked *Endpoint
+ onWritePacket func(*stack.PacketBuffer)
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if !e.linked.IsAttached() {
+ return nil
+ }
+
+ // The pipe endpoint will accept all multicast/broadcast link traffic and only
+ // unicast traffic destined to itself.
+ if len(e.linked.linkAddr) != 0 &&
+ r.RemoteLinkAddress != e.linked.linkAddr &&
+ r.RemoteLinkAddress != header.EthernetBroadcastAddress &&
+ !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) {
+ return nil
+ }
+
+ e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ panic("not implemented")
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// Wait implements stack.LinkEndpoint.
+func (*Endpoint) Wait() {}
+
+// MTU implements stack.LinkEndpoint.
+func (*Endpoint) MTU() uint32 {
+ return header.IPv6MinimumMTU
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.capabilities
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (*Endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index b6ddbe81e..f94491026 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -76,13 +76,29 @@ func (d *Device) Release(ctx context.Context) {
}
}
+// NICID returns the NIC ID of the device.
+//
+// Must only be called after the device has been attached to an endpoint.
+func (d *Device) NICID() tcpip.NICID {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+
+ if d.endpoint == nil {
+ panic("called NICID on a device that has not been attached")
+ }
+
+ return d.endpoint.nicID
+}
+
// SetIff services TUNSETIFF ioctl(2) request.
-func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+//
+// Returns true if a new NIC was created; false if an existing one was attached.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) {
d.mu.Lock()
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
// Input validations.
@@ -90,7 +106,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
isTap := flags&linux.IFF_TAP != 0
supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
prefix := "tun"
@@ -103,32 +119,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
linkCaps |= stack.CapabilityResolutionRequired
}
- endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
d.endpoint = endpoint
d.notifyHandle = d.endpoint.AddNotify(d)
d.flags = flags
- return nil
+ return created, nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
- if nic, found := s.GetNICByName(name); found {
- endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if linkEP := s.GetLinkEndpointByName(name); linkEP != nil {
+ endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, syserror.EOPNOTSUPP
+ return nil, false, syserror.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
continue
}
- return endpoint, nil
+ return endpoint, false, nil
}
}
@@ -151,12 +167,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
})
switch err {
case nil:
- return endpoint, nil
+ return endpoint, true, nil
case tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, false, syserror.EINVAL
}
}
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index b47a7be51..7df77c66e 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -49,7 +49,6 @@ type endpoint struct {
enabled uint32
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
}
@@ -92,12 +91,12 @@ func (e *endpoint) DefaultTTL() uint8 {
}
func (e *endpoint) MTU() uint32 {
- lmtu := e.linkEP.MTU()
+ lmtu := e.nic.MTU()
return lmtu - uint32(e.MaxHeaderLength())
}
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.ARPSize
+ return e.nic.MaxHeaderLength() + header.ARPSize
}
func (e *endpoint) Close() {
@@ -154,17 +153,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ origSender := h.HardwareAddressSender()
+ r.RemoteLinkAddress = tcpip.LinkAddress(origSender)
+ respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
- packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.HardwareAddressTarget(), origSender)
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt)
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
@@ -207,7 +214,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
e := &endpoint{
protocol: p,
nic: nic,
- linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
}
@@ -223,6 +229,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
+ NetProto: ProtocolNumber,
RemoteLinkAddress: remoteLinkAddr,
}
if len(r.RemoteLinkAddress) == 0 {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 6861cfdaf..d436873b6 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -270,7 +270,7 @@ func buildDummyStack(t *testing.T) *stack.Stack {
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
- tester testObject
+ testObject
mu struct {
sync.RWMutex
@@ -302,10 +302,6 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
-func (t *testInterface) LinkEndpoint() stack.LinkEndpoint {
- return &t.tester
-}
-
func TestSourceAddressValidation(t *testing.T) {
rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
@@ -517,7 +513,7 @@ func TestIPv4Send(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
@@ -538,10 +534,10 @@ func TestIPv4Send(t *testing.T) {
})
// Issue the write.
- nic.tester.protocol = 123
- nic.tester.srcAddr = localIPv4Addr
- nic.tester.dstAddr = remoteIPv4Addr
- nic.tester.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv4Addr
+ nic.testObject.dstAddr = remoteIPv4Addr
+ nic.testObject.contents = payload
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -560,12 +556,12 @@ func TestIPv4Receive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -590,10 +586,10 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -606,8 +602,8 @@ func TestIPv4Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -640,11 +636,11 @@ func TestIPv4ReceiveControl(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -691,16 +687,16 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = view[dataOffset:]
- nic.tester.typ = c.expectedTyp
- nic.tester.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; nic.tester.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
@@ -710,12 +706,12 @@ func TestIPv4FragmentationReceive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -758,10 +754,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -776,8 +772,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
// Send second segment.
@@ -788,8 +784,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -797,7 +793,7 @@ func TestIPv6Send(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
@@ -821,10 +817,10 @@ func TestIPv6Send(t *testing.T) {
})
// Issue the write.
- nic.tester.protocol = 123
- nic.tester.srcAddr = localIPv6Addr
- nic.tester.dstAddr = remoteIPv6Addr
- nic.tester.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv6Addr
+ nic.testObject.dstAddr = remoteIPv6Addr
+ nic.testObject.contents = payload
r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
@@ -843,11 +839,11 @@ func TestIPv6Receive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -871,10 +867,10 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv6Addr
- nic.tester.dstAddr = localIPv6Addr
- nic.tester.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
@@ -888,8 +884,8 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -931,11 +927,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -994,19 +990,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv6Addr
- nic.tester.dstAddr = localIPv6Addr
- nic.tester.contents = view[dataOffset:]
- nic.tester.typ = c.expectedTyp
- nic.tester.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; nic.tester.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index ee2c23e91..7fc12e229 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -32,6 +32,7 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index eab9a530c..3407755ed 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -102,8 +102,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
@@ -119,9 +117,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
}
defer r.Release()
- // Use the remote link address from the incoming packet.
- r.ResolveWith(remoteLinkAddr)
-
// TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
// header information, we may have to change this code to handle the
// ICMP header no longer being in the data buffer.
@@ -244,13 +239,7 @@ func (*icmpReasonProtoUnreachable) isICMPReason() {}
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- sent := r.Stats().ICMP.V4PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
@@ -279,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
return nil
}
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ sent := p.stack.Stats().ICMP.V4PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
networkHeader := pkt.NetworkHeader().View()
transportHeader := pkt.TransportHeader().View()
@@ -329,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// least 8 bytes of the payload must be included. Today linux and other
// systems implement the RFC 1812 definition and not the original
// requirement. We treat 8 bytes as the minimum but will try send more.
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv4MinimumProcessableDatagramSize {
mtu = header.IPv4MinimumProcessableDatagramSize
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
@@ -378,11 +386,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
counter := sent.DstUnreachable
- if err := r.WritePacket(
+ if err := route.WritePacket(
nil, /* gso */
stack.NetworkHeaderParams{
Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
+ TTL: route.DefaultTTL(),
TOS: stack.DefaultTOS,
},
icmpPkt,
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 99274dd45..115fb1ab0 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -66,7 +66,6 @@ var _ stack.NetworkEndpoint = (*endpoint)(nil)
type endpoint struct {
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
protocol *protocol
@@ -87,7 +86,6 @@ type endpoint struct {
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: nic.LinkEndpoint(),
dispatcher: dispatcher,
protocol: p,
}
@@ -178,18 +176,18 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv4MaximumHeaderSize
+ return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
// GSOMaxSize returns the maximum GSO packet size.
func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ if gso, ok := e.nic.(stack.GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@@ -210,7 +208,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint
for {
fragPkt, more := buildNextFragment(&pf, networkHeader)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1))
return err
}
@@ -283,10 +281,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, e.linkEP.MTU(), pkt)
+ if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, e.nic.MTU(), pkt)
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
@@ -316,7 +314,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
@@ -343,7 +341,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
@@ -404,7 +402,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return nil
}
- if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
@@ -512,13 +510,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC: 1122 Section 3.2.2.1
// A host SHOULD generate Destination Unreachable messages with code:
// 2 (Protocol Unreachable), when the designated transport protocol
// is not supported
- _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index f250a3cde..9916d783f 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -16,6 +16,7 @@ package ipv4_test
import (
"bytes"
+ "context"
"encoding/hex"
"math"
"net"
@@ -28,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -1492,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
lm.limit--
return false, false
}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize,
+ TTL: ipv4.DefaultTTL,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a ARP request since link address resolution should be
+ // performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != arp.ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(p.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address)
+ }
+ }
+
+ // Send an ARP reply to complete link address resolution.
+ {
+ hdr := buffer.View(make([]byte, header.ARPSize))
+ packet := header.ARP(hdr)
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), host2NICLinkAddr)
+ copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address)
+ copy(packet.HardwareAddressTarget(), host1NICLinkAddr)
+ copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address)
+ e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 37c169a5d..7be35c78b 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -440,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
localAddr := r.LocalAddress
@@ -456,9 +454,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
defer r.Release()
- // Use the link address from the source of the original packet.
- r.ResolveWith(remoteLinkAddr)
-
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data,
@@ -742,14 +737,7 @@ func (*icmpReasonPortUnreachable) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- stats := r.Stats().ICMP
- sent := stats.V6PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// Only send ICMP error if the address is not a multicast v6
// address and the source is not the unspecified address.
//
@@ -780,6 +768,26 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
return nil
}
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ stats := p.stack.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
@@ -806,11 +814,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// packet that caused the error) as possible without making
// the error message packet exceed the minimum IPv6 MTU
// [IPv6].
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv6MinimumMTU {
mtu = header.IPv6MinimumMTU
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
available := int(mtu) - headerLen
if available < header.IPv6MinimumSize {
return nil
@@ -843,9 +851,16 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data))
- err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt)
- if err != nil {
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data))
+ if err := route.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ newPkt,
+ ); err != nil {
sent.Dropped.Increment()
return err
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 6f13339a4..3affcc4e4 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -16,9 +16,11 @@ package ipv6
import (
"context"
+ "net"
"reflect"
"strings"
"testing"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -28,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -40,6 +43,9 @@ const (
defaultChannelSize = 1
defaultMTU = 65536
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 30 * time.Second
)
var (
@@ -110,7 +116,9 @@ func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
var _ stack.NetworkInterface = (*testInterface)(nil)
-type testInterface struct{}
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
func (*testInterface) ID() tcpip.NICID {
return 0
@@ -128,10 +136,6 @@ func (*testInterface) Enabled() bool {
return true
}
-func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
- return nil
-}
-
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -1281,3 +1285,210 @@ func TestLinkAddressRequest(t *testing.T) {
))
}
}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a neighbor solicitation since link address resolution should
+ // be performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}),
+ ))
+ }
+
+ // Send a neighbor advertisement to complete link address resolution.
+ {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 826342c4f..4f360df2c 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -66,7 +66,6 @@ var _ NDPEndpoint = (*endpoint)(nil)
type endpoint struct {
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
dispatcher stack.TransportDispatcher
@@ -364,18 +363,18 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
// GSOMaxSize returns the maximum GSO packet size.
func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ if gso, ok := e.nic.(stack.GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@@ -396,7 +395,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
}
func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
- return pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone)
+ return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone)
}
// handleFragments fragments pkt and calls the handler function on each
@@ -477,19 +476,19 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.linkEP.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
// WritePackets(). It'll be faster but cost more memory.
- return e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
})
r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
@@ -511,7 +510,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
e.addIPHeader(r, pb, params)
if e.packetMustBeFragmented(pb, gso) {
current := pb
- _, _, err := e.handleFragments(r, gso, e.linkEP.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(current, fragPkt)
current = current.Next()
@@ -536,7 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
@@ -563,7 +562,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
// Dropped packets aren't errors, so include them in
@@ -643,7 +642,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
if previousHeaderStart != 0 {
- _ = returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: previousHeaderStart,
}, pkt)
@@ -682,7 +681,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// ICMP Parameter Problem, Code 2, message to the packet's
// Source Address, pointing to the unrecognized Option Type.
//
- _ = returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
@@ -707,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
if extHdr.SegmentsLeft() != 0 {
- _ = returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -859,7 +858,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// ICMP Parameter Problem, Code 2, message to the packet's
// Source Address, pointing to the unrecognized Option Type.
//
- _ = returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
@@ -896,7 +895,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// message with Code 4 in response to a packet for which the
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC 8200 section 4. (page 7):
// Extension headers are numbered from IANA IP Protocol Numbers
@@ -917,7 +916,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
//
// Which when taken together indicate that an unknown protocol should
// be treated as an unrecognized next header value.
- _ = returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -927,7 +926,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
default:
- _ = returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -1302,7 +1301,6 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 48a4c65e3..40da011f8 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1289,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
//
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
- linkAddr := ndp.ep.linkEP.LinkAddress()
+ linkAddr := ndp.ep.nic.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
return false
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 2eaeab779..eba97334e 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -56,7 +56,6 @@ go_library(
srcs = [
"addressable_endpoint_state.go",
"conntrack.go",
- "forwarder.go",
"headertype_string.go",
"icmp_rate_limit.go",
"iptables.go",
@@ -73,6 +72,7 @@ go_library(
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
+ "pending_packets.go",
"rand.go",
"registration.go",
"route.go",
@@ -123,7 +123,6 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
@@ -139,7 +138,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "forwarder_test.go",
+ "forwarding_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarding_test.go
index 4e4b00a92..cf042309e 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -48,10 +48,9 @@ const (
type fwdTestNetworkEndpoint struct {
AddressableEndpointState
- nicID tcpip.NICID
+ nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
- ep LinkEndpoint
}
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
@@ -67,7 +66,7 @@ func (*fwdTestNetworkEndpoint) Enabled() bool {
func (*fwdTestNetworkEndpoint) Disable() {}
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
@@ -80,7 +79,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+ return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -99,7 +98,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
@@ -159,10 +158,9 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
e := &fwdTestNetworkEndpoint{
- nicID: nic.ID(),
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 9a72bec79..4d69a4de1 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
// There is no need to log the error here; the NUD implementation may
// assume a working link. A valid link should be the responsibility of
// the NIC/stack.LinkEndpoint.
@@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index a265fff0a..e79abebca 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -227,8 +227,9 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
- id: entryTestNICID,
- linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+
+ id: entryTestNICID,
stack: &Stack{
clock: clock,
nudDisp: &disp,
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 6cf54cc89..8828cc5fe 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -32,10 +32,11 @@ var _ NetworkInterface = (*NIC)(nil)
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
+ LinkEndpoint
+
stack *Stack
id tcpip.NICID
name string
- linkEP LinkEndpoint
context NICContext
stats NICStats
@@ -91,10 +92,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
+ LinkEndpoint: ep,
+
stack: stack,
id: id,
name: name,
- linkEP: ep,
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
@@ -130,7 +132,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
- nic.linkEP.Attach(nic)
+ nic.LinkEndpoint.Attach(nic)
return nic
}
@@ -220,7 +222,7 @@ func (n *NIC) remove() *tcpip.Error {
}
// Detach from link endpoint, so no packet comes in.
- n.linkEP.Attach(nil)
+ n.LinkEndpoint.Attach(nil)
return nil
}
@@ -240,7 +242,64 @@ func (n *NIC) isPromiscuousMode() bool {
// IsLoopback implements NetworkInterface.
func (n *NIC) IsLoopback() bool {
- return n.linkEP.Capabilities()&CapabilityLoopback != 0
+ return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0
+}
+
+// WritePacket implements NetworkLinkEndpoint.
+func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // As per relevant RFCs, we should queue packets while we wait for link
+ // resolution to complete.
+ //
+ // RFC 1122 section 2.3.2.2 (for IPv4):
+ // The link layer SHOULD save (rather than discard) at least
+ // one (the latest) packet of each set of packets destined to
+ // the same unresolved IP address, and transmit the saved
+ // packet when the address has been resolved.
+ //
+ // RFC 4861 section 5.2 (for IPv6):
+ // Once the IP address of the next-hop node is known, the sender
+ // examines the Neighbor Cache for link-layer information about that
+ // neighbor. If no entry exists, the sender creates one, sets its state
+ // to INCOMPLETE, initiates Address Resolution, and then queues the data
+ // packet pending completion of address resolution.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ r := r.Clone()
+ n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ return nil
+ }
+ return err
+ }
+
+ return n.writePacket(r, gso, protocol, pkt)
+}
+
+func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Size()
+
+ if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
+ return err
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
+}
+
+// WritePackets implements NetworkLinkEndpoint.
+func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution
+ // is being peformed like WritePacket.
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
+ n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Size()
+ }
+
+ n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ return writtenPackets, err
}
// setSpoofing enables or disables address spoofing.
@@ -525,7 +584,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// If no local link layer address is provided, assume it was sent
// directly to this NIC.
if local == "" {
- local = n.linkEP.LinkAddress()
+ local = n.LinkEndpoint.LinkAddress()
}
// Are any packet type sockets listening for this network protocol?
@@ -605,7 +664,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n := r.nic
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
@@ -620,21 +679,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
- // TODO(b/128629022): move this logic to route.WritePacket.
// TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
- // forwarder will release route.
- return
- }
+
+ // pkt may have set its header and may not have enough headroom for
+ // link-layer header for the other link to prepend. Here we create a new
+ // packet to forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- r.Release()
- return
}
- // The link-address resolution finished immediately.
- n.forwardPacket(&r, protocol, pkt)
r.Release()
return
}
@@ -658,34 +717,11 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
- n.linkEP.AddHeader(local, remote, protocol, p)
+ n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
-
- // pkt may have set its header and may not have enough headroom for link-layer
- // header for the other link to prepend. Here we create a new packet to
- // forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- })
-
- // WritePacket takes ownership of fwdPkt, calculate numBytes first.
- numBytes := fwdPkt.Size()
-
- if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return
- }
-
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
@@ -796,11 +832,6 @@ func (n *NIC) Name() string {
return n.name
}
-// LinkEndpoint implements NetworkInterface.
-func (n *NIC) LinkEndpoint() LinkEndpoint {
- return n.linkEP
-}
-
// nudConfigs gets the NUD configurations for n.
func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
if n.neigh == nil {
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index fdd49b77f..97a96af62 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -33,8 +33,7 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
type testIPv6Endpoint struct {
AddressableEndpointState
- nicID tcpip.NICID
- linkEP LinkEndpoint
+ nic NetworkInterface
protocol *testIPv6Protocol
invalidatedRtr tcpip.Address
@@ -57,12 +56,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 {
// MTU implements NetworkEndpoint.MTU.
func (e *testIPv6Endpoint) MTU() uint32 {
- return e.linkEP.MTU() - header.IPv6MinimumSize
+ return e.nic.MTU() - header.IPv6MinimumSize
}
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
// WritePacket implements NetworkEndpoint.WritePacket.
@@ -134,8 +133,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint implements NetworkProtocol.NewEndpoint.
func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
e := &testIPv6Endpoint{
- nicID: nic.ID(),
- linkEP: nic.LinkEndpoint(),
+ nic: nic,
protocol: p,
}
e.AddressableEndpointState.Init(e)
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/pending_packets.go
index 3eff141e6..f838eda8d 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -29,60 +29,60 @@ const (
)
type pendingPacket struct {
- nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
pkt *PacketBuffer
}
-type forwardQueue struct {
+// packetsPendingLinkResolution is a queue of packets pending link resolution.
+//
+// Once link resolution completes successfully, the packets will be written.
+type packetsPendingLinkResolution struct {
sync.Mutex
// The packets to send once the resolver completes.
- packets map[<-chan struct{}][]*pendingPacket
+ packets map[<-chan struct{}][]pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
cancelChans []chan struct{}
}
-func newForwardQueue() *forwardQueue {
- return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+func (f *packetsPendingLinkResolution) init() {
+ f.Lock()
+ defer f.Unlock()
+ f.packets = make(map[<-chan struct{}][]pendingPacket)
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- shouldWait := false
-
+func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
f.Lock()
+ defer f.Unlock()
+
packets, ok := f.packets[ch]
- if !ok {
- shouldWait = true
- }
- for len(packets) == maxPendingPacketsPerResolution {
+ if len(packets) == maxPendingPacketsPerResolution {
p := packets[0]
+ packets[0] = pendingPacket{}
packets = packets[1:]
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
p.route.Release()
}
+
if l := len(packets); l >= maxPendingPacketsPerResolution {
panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
}
- f.packets[ch] = append(packets, &pendingPacket{
- nic: n,
+
+ f.packets[ch] = append(packets, pendingPacket{
route: r,
- proto: protocol,
+ proto: proto,
pkt: pkt,
})
- f.Unlock()
- if !shouldWait {
+ if ok {
return
}
// Wait for the link-address resolution to complete.
- // Start a goroutine with a forwarding-cancel channel so that we can
- // limit the maximum number of goroutines running concurrently.
- cancel := f.newCancelChannel()
+ cancel := f.newCancelChannelLocked()
go func() {
cancelled := false
select {
@@ -92,17 +92,21 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
}
f.Lock()
- packets := f.packets[ch]
+ packets, ok := f.packets[ch]
delete(f.packets, ch)
f.Unlock()
+ if !ok {
+ panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
+ }
+
for _, p := range packets {
if cancelled {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else if _, err := p.route.Resolve(nil); err != nil {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
@@ -112,12 +116,10 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
// newCancelChannel creates a channel that can cancel a pending forwarding
// activity. The oldest channel is closed if the number of open channels would
// exceed maxPendingResolutions.
-func (f *forwardQueue) newCancelChannel() chan struct{} {
- f.Lock()
- defer f.Unlock()
-
+func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
if len(f.cancelChans) == maxPendingResolutions {
ch := f.cancelChans[0]
+ f.cancelChans[0] = nil
f.cancelChans = f.cancelChans[1:]
close(ch)
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index be9bd8042..defb9129b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -475,6 +475,8 @@ type NDPEndpoint interface {
// NetworkInterface is a network interface.
type NetworkInterface interface {
+ NetworkLinkEndpoint
+
// ID returns the interface's ID.
ID() tcpip.NICID
@@ -488,9 +490,6 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
-
- // LinkEndpoint returns the link endpoint backing the interface.
- LinkEndpoint() LinkEndpoint
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -663,22 +662,15 @@ const (
CapabilitySoftwareGSO
)
-// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
-// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint. When a link header exists,
-// it sets each PacketBuffer's LinkHeader field before passing it up the
-// stack.
-type LinkEndpoint interface {
+// NetworkLinkEndpoint is a data-link layer that supports sending network
+// layer packets.
+type NetworkLinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
// physical network doesn't exist, the limit is generally 64k, which
// includes the maximum size of an IP packet.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the data link (and
// lower level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -686,7 +678,7 @@ type LinkEndpoint interface {
MaxHeaderLength() uint16
// LinkAddress returns the link address (typically a MAC) of the
- // link endpoint.
+ // endpoint.
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the
@@ -706,6 +698,19 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+}
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each PacketBuffer's LinkHeader field before passing it up the
+// stack.
+type LinkEndpoint interface {
+ NetworkLinkEndpoint
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header. It takes ownership of vv.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index cc39c9a6a..710a02ac3 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -72,21 +72,20 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
- linkEP := nic.LinkEndpoint()
r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: linkEP.LinkAddress(),
+ LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
RemoteAddress: remoteAddr,
addressEndpoint: addressEndpoint,
nic: nic,
Loop: loop,
}
- if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = nic.stack
+ r.linkCache = r.nic.stack
}
}
@@ -116,7 +115,7 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint().Capabilities()
+ return r.nic.LinkEndpoint.Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
@@ -127,12 +126,6 @@ func (r *Route) GSOMaxSize() uint32 {
return 0
}
-// ResolveWith immediately resolves a route with the specified remote link
-// address.
-func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
-}
-
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
@@ -208,16 +201,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
return tcpip.ErrInvalidEndpointState
}
- // WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Size()
-
- if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil {
- return err
- }
-
- r.nic.stats.Tx.Packets.Increment()
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- return nil
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
@@ -227,15 +211,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
return 0, tcpip.ErrInvalidEndpointState
}
- n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
- r.nic.stats.Tx.Packets.IncrementBy(uint64(n))
- writtenBytes := 0
- for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
- writtenBytes += pb.Size()
- }
-
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
- return n, err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -245,15 +221,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Data.Size()
-
- if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil {
- return err
- }
- r.nic.stats.Tx.Packets.Increment()
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- return nil
+ return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0bf20c0e1..3a07577c8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -436,9 +436,9 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
- // forwarder holds the packets that wait for their link-address resolutions
- // to complete, and forwards them when each resolution is done.
- forwarder *forwardQueue
+ // linkResQueue holds packets that are waiting for link resolution to
+ // complete.
+ linkResQueue packetsPendingLinkResolution
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
@@ -550,8 +550,8 @@ type TransportEndpointInfo struct {
// incompatible with the receiver.
//
// Preconditon: the parent endpoint mu must be held while calling this method.
-func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
+func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := t.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
netProto = header.IPv4ProtocolNumber
@@ -565,7 +565,7 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
}
- switch len(e.ID.LocalAddress) {
+ switch len(t.ID.LocalAddress) {
case header.IPv4AddressSize:
if len(addr.Addr) == header.IPv6AddressSize {
return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
@@ -577,8 +577,8 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
switch {
- case netProto == e.NetProto:
- case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ case netProto == t.NetProto:
+ case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber:
if v6only {
return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
}
@@ -640,7 +640,6 @@ func New(opts Options) *Stack {
useNeighborCache: opts.UseNeighborCache,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
- forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
@@ -653,6 +652,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
+ s.linkResQueue.init()
// Add specified network protocols.
for _, netProtoFactory := range opts.NetworkProtocols {
@@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
-// GetNICByName gets the NIC specified by name.
-func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+// GetLinkEndpointByName gets the link endpoint specified by name.
+func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint {
s.mu.RLock()
defer s.mu.RUnlock()
for _, nic := range s.nics {
if nic.Name() == name {
- return nic, true
+ return nic.LinkEndpoint
}
}
- return nil, false
+ return nil
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
}
nics[id] = NICInfo{
Name: nic.name,
- LinkAddress: nic.linkEP.LinkAddress(),
+ LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
- MTU: nic.linkEP.MTU(),
+ MTU: nic.LinkEndpoint.MTU(),
Stats: nic.stats,
Context: nic.context,
- ARPHardwareType: nic.linkEP.ARPHardwareType(),
+ ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
}
}
return nics
@@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1539,7 +1539,7 @@ func (s *Stack) Wait() {
s.mu.RLock()
defer s.mu.RUnlock()
for _, n := range s.nics {
- n.linkEP.Wait()
+ n.LinkEndpoint.Wait()
}
}
@@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
// Add our own fake ethernet header.
ethFields := header.EthernetFields{
- SrcAddr: nic.linkEP.LinkAddress(),
+ SrcAddr: nic.LinkEndpoint.LinkAddress(),
DstAddr: dst,
Type: netProto,
}
@@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
vv := buffer.View(fakeHeader).ToVectorisedView()
vv.Append(payload)
- if err := nic.linkEP.WriteRawPacket(vv); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
return err
}
@@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView)
return tcpip.ErrUnknownDevice
}
- if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
return err
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index aa20f750b..38994cca1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,7 +21,6 @@ import (
"bytes"
"fmt"
"math"
- "net"
"sort"
"testing"
"time"
@@ -35,7 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -77,10 +75,9 @@ type fakeNetworkEndpoint struct {
enabled bool
}
- nicID tcpip.NICID
+ nic stack.NetworkInterface
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- ep stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
@@ -103,7 +100,7 @@ func (f *fakeNetworkEndpoint) Disable() {
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
@@ -135,7 +132,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fakeNetHeaderLen
+ return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -164,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
return nil
}
- return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -216,10 +213,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &fakeNetworkEndpoint{
- nicID: nic.ID(),
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
@@ -2106,7 +2102,7 @@ func TestNICStats(t *testing.T) {
t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
+ if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
@@ -3502,52 +3498,6 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
}
-func TestResolveWith(t *testing.T) {
- const (
- unspecifiedNICID = 0
- nicID = 1
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- })
- ep := channel.New(0, defaultMTU, "")
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- addr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
- PrefixLen: 24,
- },
- }
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
-
- remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4())
- r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
- }
- defer r.Release()
-
- // Should initially require resolution.
- if !r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = false, want = true")
- }
-
- // Manually resolving the route should no longer require resolution.
- r.ResolveWith("\x01")
- if r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = true, want = false")
- }
-}
-
// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
// associated address is removed should not cause a panic.
func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 06c7a3cd3..a4f141253 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -6,6 +6,8 @@ go_test(
name = "integration_test",
size = "small",
srcs = [
+ "forward_test.go",
+ "link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
],
@@ -15,6 +17,8 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/pipe",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
new file mode 100644
index 000000000..ffd38ee1a
--- /dev/null
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -0,0 +1,378 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestForwarding(t *testing.T) {
+ const (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
+ routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1NICID = 1
+ routerNICID1 = 2
+ routerNICID2 = 3
+ host2NICID = 4
+
+ listenPort = 8080
+ )
+
+ host1IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ routerNIC1IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ routerNIC2IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host2IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host1IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ routerNIC1IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ routerNIC2IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("b::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("b::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ {
+ name: "IPv6 host2 server with host1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired)
+ routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+
+ if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
+ }
+
+ if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ })
+ routerStack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ tcpip.Route{
+ Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ tcpip.Route{
+ Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ tcpip.Route{
+ Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ })
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack)
+ defer epsAndAddrs.serverEP.Close()
+ defer epsAndAddrs.clientEP.Close()
+
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) {
+ t.Helper()
+
+ dataPayload := tcpip.SlicePayload(data)
+ wOpts := tcpip.WriteOptions{To: to}
+ n, ch, err := ep.Write(dataPayload, wOpts)
+ if err == tcpip.ErrNoLinkAddress {
+ // Wait for link resolution to complete.
+ <-ch
+
+ n, _, err = ep.Write(dataPayload, wOpts)
+ } else if err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ }
+
+ if err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ }
+ }
+
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data, &serverAddr)
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress {
+ t.Helper()
+
+ // Wait for the endpoint to be readable.
+ <-ch
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+
+ if diff := cmp.Diff(v, buffer.View(data)); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != expectedFrom {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ return addr
+ }
+
+ addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr)
+ // Unspecify the NIC since NIC IDs are meaningless across stacks.
+ addr.NIC = 0
+
+ data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ write(epsAndAddrs.serverEP, data, &addr)
+ addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr)
+ if addr.Port != listenPort {
+ t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
new file mode 100644
index 000000000..bf3a6f6ee
--- /dev/null
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -0,0 +1,219 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+)
+
+// TestPing tests that two hosts can ping eachother when link resolution is
+// enabled.
+func TestPing(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+
+ // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
+ // request/reply packets.
+ icmpDataOffset = 8
+ )
+
+ tests := []struct {
+ name string
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ icmpBuf func(*testing.T) buffer.View
+ }{
+ {
+ name: "IPv4 Ping",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ icmpBuf: func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv4Echo)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ },
+ },
+ {
+ name: "IPv6 Ping",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ icmpBuf: func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv6EchoRequest)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+
+ if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, waiterCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
+
+ // The first write should trigger link resolution.
+ icmpBuf := test.icmpBuf(t)
+ wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
+ if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress)
+ } else {
+ // Wait for link resolution to complete.
+ <-ch
+ }
+ if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ } else if want := int64(len(icmpBuf)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ }
+
+ // Wait for the endpoint to be readable.
+ <-waiterCH
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+ if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != test.remoteAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index cddedb686..b4604ba35 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1432,7 +1432,9 @@ func TestNoChecksum(t *testing.T) {
var _ stack.NetworkInterface = (*testInterface)(nil)
-type testInterface struct{}
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
func (*testInterface) ID() tcpip.NICID {
return 0
@@ -1450,10 +1452,6 @@ func (*testInterface) Enabled() bool {
return true
}
-func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
- return nil
-}
-
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {