summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/netfilter.go41
-rw-r--r--pkg/abi/linux/netfilter_test.go1
-rw-r--r--pkg/sentry/control/proc.go6
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/control.go2
-rw-r--r--pkg/sentry/fs/host/file.go10
-rw-r--r--pkg/sentry/fs/host/inode_test.go3
-rw-r--r--pkg/sentry/fs/host/wait_test.go3
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs.go5
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD1
-rw-r--r--pkg/sentry/fsimpl/ext/ext.go12
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go12
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go8
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go7
-rw-r--r--pkg/sentry/fsimpl/host/host.go125
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go13
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go7
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go8
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go11
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go9
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go7
-rw-r--r--pkg/sentry/kernel/task.go12
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go7
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go128
-rw-r--r--pkg/sentry/socket/netstack/provider.go6
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/anonfs.go22
-rw-r--r--pkg/sentry/vfs/filesystem.go19
-rw-r--r--pkg/sentry/vfs/filesystem_type.go3
-rw-r--r--pkg/sentry/vfs/mount.go7
-rw-r--r--pkg/sentry/vfs/pathname.go43
-rw-r--r--pkg/sentry/vfs/vfs.go46
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/tcpip/buffer/view.go20
-rw-r--r--pkg/tcpip/header/BUILD3
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go531
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go957
-rw-r--r--pkg/tcpip/network/hash/hash.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go15
-rw-r--r--pkg/tcpip/network/ipv6/BUILD3
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go190
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go981
-rw-r--r--pkg/tcpip/stack/ndp.go29
-rw-r--r--pkg/tcpip/stack/nic.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go9
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go107
-rw-r--r--pkg/tcpip/stack/transport_test.go2
-rw-r--r--pkg/tcpip/tcpip.go12
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go12
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go9
-rw-r--r--pkg/tcpip/transport/tcp/accept.go5
-rw-r--r--pkg/tcpip/transport/tcp/connect.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go12
-rw-r--r--runsc/boot/fds.go5
-rwxr-xr-xscripts/iptables_tests.sh4
-rw-r--r--test/iptables/filter_input.go22
-rw-r--r--test/iptables/filter_output.go143
-rw-r--r--test/iptables/iptables_test.go30
-rw-r--r--test/syscalls/linux/BUILD5
-rw-r--r--test/syscalls/linux/sysret.cc35
70 files changed, 3607 insertions, 171 deletions
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 80dc09aa9..a8d4f9d69 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -509,3 +509,44 @@ const (
// Enable all flags.
XT_UDP_INV_MASK = 0x03
)
+
+// IPTOwnerInfo holds data for matching packets with owner. It corresponds
+// to struct ipt_owner_info in libxt_owner.c of iptables binary.
+type IPTOwnerInfo struct {
+ // UID is user id which created the packet.
+ UID uint32
+
+ // GID is group id which created the packet.
+ GID uint32
+
+ // PID is process id of the process which created the packet.
+ PID uint32
+
+ // SID is session id which created the packet.
+ SID uint32
+
+ // Comm is the command name which created the packet.
+ Comm [16]byte
+
+ // Match is used to match UID/GID of the socket. See the
+ // XT_OWNER_* flags below.
+ Match uint8
+
+ // Invert flips the meaning of Match field.
+ Invert uint8
+}
+
+// SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo.
+const SizeOfIPTOwnerInfo = 34
+
+// Flags in IPTOwnerInfo.Match. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_owner.h.
+const (
+ // Match the UID of the packet.
+ XT_OWNER_UID = 1 << 0
+ // Match the GID of the packet.
+ XT_OWNER_GID = 1 << 1
+ // Match if the socket exists for the packet. Forwarded
+ // packets do not have an associated socket.
+ XT_OWNER_SOCKET = 1 << 2
+)
diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go
index 21e237f92..565dd550e 100644
--- a/pkg/abi/linux/netfilter_test.go
+++ b/pkg/abi/linux/netfilter_test.go
@@ -29,6 +29,7 @@ func TestSizes(t *testing.T) {
{IPTGetEntries{}, SizeOfIPTGetEntries},
{IPTGetinfo{}, SizeOfIPTGetinfo},
{IPTIP{}, SizeOfIPTIP},
+ {IPTOwnerInfo{}, SizeOfIPTOwnerInfo},
{IPTReplace{}, SizeOfIPTReplace},
{XTCounters{}, SizeOfXTCounters},
{XTEntryMatch{}, SizeOfXTEntryMatch},
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 5457ba5e7..b51fb3959 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -224,8 +224,6 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
}
}
- mounter := fs.FileOwnerFromContext(ctx)
-
// TODO(gvisor.dev/issue/1623): Use host FD when supported in VFS2.
var ttyFile *fs.File
for appFD, hostFile := range args.FilePayload.Files {
@@ -235,7 +233,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
// Import the file as a host TTY file.
if ttyFile == nil {
var err error
- appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), mounter, true /* isTTY */)
+ appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), true /* isTTY */)
if err != nil {
return nil, 0, nil, err
}
@@ -254,7 +252,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
} else {
// Import the file as a regular host file.
var err error
- appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), mounter, false /* isTTY */)
+ appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), false /* isTTY */)
if err != nil {
return nil, 0, nil, err
}
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 011625c80..aabce6cc9 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -71,7 +71,6 @@ go_test(
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/sentry/contexttest",
- "//pkg/sentry/fs",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go
index cd84e1337..52c0504b6 100644
--- a/pkg/sentry/fs/host/control.go
+++ b/pkg/sentry/fs/host/control.go
@@ -78,7 +78,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*fs.File {
}
// Create the file backed by hostFD.
- file, err := NewFile(ctx, fd, fs.FileOwnerFromContext(ctx))
+ file, err := NewFile(ctx, fd)
if err != nil {
ctx.Warningf("Error creating file from host FD: %v", err)
break
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
index 034862694..3e48b8b2c 100644
--- a/pkg/sentry/fs/host/file.go
+++ b/pkg/sentry/fs/host/file.go
@@ -60,8 +60,8 @@ var _ fs.FileOperations = (*fileOperations)(nil)
// The returned File cannot be saved, since there is no guarantee that the same
// FD will exist or represent the same file at time of restore. If such a
// guarantee does exist, use ImportFile instead.
-func NewFile(ctx context.Context, fd int, mounter fs.FileOwner) (*fs.File, error) {
- return newFileFromDonatedFD(ctx, fd, mounter, false, false)
+func NewFile(ctx context.Context, fd int) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, false, false)
}
// ImportFile creates a new File backed by the provided host file descriptor.
@@ -71,13 +71,13 @@ func NewFile(ctx context.Context, fd int, mounter fs.FileOwner) (*fs.File, error
// If the returned file is saved, it will be restored by re-importing the FD
// originally passed to ImportFile. It is the restorer's responsibility to
// ensure that the FD represents the same file.
-func ImportFile(ctx context.Context, fd int, mounter fs.FileOwner, isTTY bool) (*fs.File, error) {
- return newFileFromDonatedFD(ctx, fd, mounter, true, isTTY)
+func ImportFile(ctx context.Context, fd int, isTTY bool) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, true, isTTY)
}
// newFileFromDonatedFD returns an fs.File from a donated FD. If the FD is
// saveable, then saveable is true.
-func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner, saveable, isTTY bool) (*fs.File, error) {
+func newFileFromDonatedFD(ctx context.Context, donated int, saveable, isTTY bool) (*fs.File, error) {
var s syscall.Stat_t
if err := syscall.Fstat(donated, &s); err != nil {
return nil, err
diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go
index 4c374681c..c507f57eb 100644
--- a/pkg/sentry/fs/host/inode_test.go
+++ b/pkg/sentry/fs/host/inode_test.go
@@ -19,7 +19,6 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
)
// TestCloseFD verifies fds will be closed.
@@ -33,7 +32,7 @@ func TestCloseFD(t *testing.T) {
// Use the write-end because we will detect if it's closed on the read end.
ctx := contexttest.Context(t)
- file, err := NewFile(ctx, p[1], fs.RootOwner)
+ file, err := NewFile(ctx, p[1])
if err != nil {
t.Fatalf("Failed to create File: %v", err)
}
diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go
index d49c3a635..ce397a5e3 100644
--- a/pkg/sentry/fs/host/wait_test.go
+++ b/pkg/sentry/fs/host/wait_test.go
@@ -20,7 +20,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -34,7 +33,7 @@ func TestWait(t *testing.T) {
defer syscall.Close(fds[1])
ctx := contexttest.Context(t)
- file, err := NewFile(ctx, fds[0], fs.RootOwner)
+ file, err := NewFile(ctx, fds[0])
if err != nil {
syscall.Close(fds[0])
t.Fatalf("NewFile failed: %v", err)
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index abd4f24e7..64f1b142c 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -42,6 +42,11 @@ type FilesystemType struct {
root *vfs.Dentry
}
+// Name implements vfs.FilesystemType.Name.
+func (*FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fst *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fst.initOnce.Do(func() {
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 6f78f478f..d83d75b3d 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -45,6 +45,7 @@ go_library(
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
"//pkg/sync",
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index 373d23b74..7176af6d1 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -30,6 +30,9 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// Name is the name of this filesystem.
+const Name = "ext"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
@@ -91,8 +94,13 @@ func isCompatible(sb disklayout.SuperBlock) bool {
return true
}
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
-func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// TODO(b/134676337): Ensure that the user is mounting readonly. If not,
// EACCESS should be returned according to mount(2). Filesystem independent
// flags (like readonly) are currently not available in pkg/sentry/vfs.
@@ -103,7 +111,7 @@ func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFile
}
fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
- fs.vfsfs.Init(vfsObj, &fs)
+ fs.vfsfs.Init(vfsObj, &fsType, &fs)
fs.sb, err = readSuperBlock(dev)
if err != nil {
return nil, nil, err
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 8497be615..48eaccdbc 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -463,6 +464,17 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return syserror.EROFS
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return nil, err
+ }
+
+ // TODO(b/134676337): Support sockets.
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
_, _, err := fs.walk(rp, false)
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 4ba76a1e8..d15a36709 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -46,6 +46,7 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 1e43df9ec..269624362 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -1059,6 +1060,13 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return fs.unlinkAt(ctx, rp, false /* dir */)
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+//
+// TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
var ds *[]*dentry
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index cf276a417..8e41b6b1c 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -199,6 +199,11 @@ const (
InteropModeShared
)
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
@@ -374,7 +379,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
dentries: make(map[*dentry]struct{}),
specialFileFDs: make(map[*specialFileFD]struct{}),
}
- fs.vfsfs.Init(vfsObj, fs)
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
// Construct the root dentry.
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index a54985ef5..7d9dcd4c9 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -38,6 +38,19 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (filesystemType) GetFilesystem(context.Context, *vfs.VirtualFilesystem, *auth.Credentials, string, vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("cannot instaniate a host filesystem")
+}
+
+// Name implements FilesystemType.Name.
+func (filesystemType) Name() string {
+ return "none"
+}
+
// filesystem implements vfs.FilesystemImpl.
type filesystem struct {
kernfs.Filesystem
@@ -46,7 +59,7 @@ type filesystem struct {
// NewMount returns a new disconnected mount in vfsObj that may be passed to ImportFD.
func NewMount(vfsObj *vfs.VirtualFilesystem) (*vfs.Mount, error) {
fs := &filesystem{}
- fs.Init(vfsObj)
+ fs.Init(vfsObj, &filesystemType{})
vfsfs := fs.VFSFilesystem()
// NewDisconnectedMount will take an additional reference on vfsfs.
defer vfsfs.DecRef()
@@ -54,7 +67,7 @@ func NewMount(vfsObj *vfs.VirtualFilesystem) (*vfs.Mount, error) {
}
// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
-func ImportFD(mnt *vfs.Mount, hostFD int, ownerUID auth.KUID, ownerGID auth.KGID, isTTY bool) (*vfs.FileDescription, error) {
+func ImportFD(mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) {
fs, ok := mnt.Filesystem().Impl().(*kernfs.Filesystem)
if !ok {
return nil, fmt.Errorf("can't import host FDs into filesystems of type %T", mnt.Filesystem().Impl())
@@ -78,8 +91,6 @@ func ImportFD(mnt *vfs.Mount, hostFD int, ownerUID auth.KUID, ownerGID auth.KGID
canMap: canMap(uint32(fileType)),
ino: fs.NextIno(),
mode: fileMode,
- uid: ownerUID,
- gid: ownerGID,
// For simplicity, set offset to 0. Technically, we should
// only set to 0 on files that are not seekable (sockets, pipes, etc.),
// and use the offset from the host fd otherwise.
@@ -135,17 +146,20 @@ type inode struct {
// This field is initialized at creation time and is immutable.
ino uint64
- // TODO(gvisor.dev/issue/1672): protect mode, uid, and gid with mutex.
+ // modeMu protects mode.
+ modeMu sync.Mutex
- // mode is the file mode of this inode. Note that this value may become out
- // of date if the mode is changed on the host, e.g. with chmod.
+ // mode is a cached version of the file mode on the host. Note that it may
+ // become out of date if the mode is changed on the host, e.g. with chmod.
+ //
+ // Generally, it is better to retrieve the mode from the host through an
+ // fstat syscall. We only use this value in inode.Mode(), which cannot
+ // return an error, if the syscall to host fails.
+ //
+ // FIXME(b/152294168): Plumb error into Inode.Mode() return value so we
+ // can get rid of this.
mode linux.FileMode
- // uid and gid of the file owner. Note that these refer to the owner of the
- // file created on import, not the fd on the host.
- uid auth.KUID
- gid auth.KGID
-
// offsetMu protects offset.
offsetMu sync.Mutex
@@ -168,12 +182,35 @@ func fileFlagsFromHostFD(fd int) (int, error) {
// CheckPermissions implements kernfs.Inode.
func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
- return vfs.GenericCheckPermissions(creds, ats, i.mode, i.uid, i.gid)
+ mode, uid, gid, err := i.getPermissions()
+ if err != nil {
+ return err
+ }
+ return vfs.GenericCheckPermissions(creds, ats, mode, uid, gid)
}
// Mode implements kernfs.Inode.
func (i *inode) Mode() linux.FileMode {
- return i.mode
+ mode, _, _, err := i.getPermissions()
+ if err != nil {
+ return i.mode
+ }
+
+ return linux.FileMode(mode)
+}
+
+func (i *inode) getPermissions() (linux.FileMode, auth.KUID, auth.KGID, error) {
+ // Retrieve metadata.
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return 0, 0, 0, err
+ }
+
+ // Update cached mode.
+ i.modeMu.Lock()
+ i.mode = linux.FileMode(s.Mode)
+ i.modeMu.Unlock()
+ return linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid), nil
}
// Stat implements kernfs.Inode.
@@ -213,45 +250,51 @@ func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, erro
ls.Attributes = s.Attributes
ls.AttributesMask = s.Attributes_mask
- if mask|linux.STATX_TYPE != 0 {
+ if mask&linux.STATX_TYPE != 0 {
ls.Mode |= s.Mode & linux.S_IFMT
}
- if mask|linux.STATX_MODE != 0 {
+ if mask&linux.STATX_MODE != 0 {
ls.Mode |= s.Mode &^ linux.S_IFMT
}
- if mask|linux.STATX_NLINK != 0 {
+ if mask&linux.STATX_NLINK != 0 {
ls.Nlink = s.Nlink
}
- if mask|linux.STATX_ATIME != 0 {
+ if mask&linux.STATX_UID != 0 {
+ ls.UID = s.Uid
+ }
+ if mask&linux.STATX_GID != 0 {
+ ls.GID = s.Gid
+ }
+ if mask&linux.STATX_ATIME != 0 {
ls.Atime = unixToLinuxStatxTimestamp(s.Atime)
}
- if mask|linux.STATX_BTIME != 0 {
+ if mask&linux.STATX_BTIME != 0 {
ls.Btime = unixToLinuxStatxTimestamp(s.Btime)
}
- if mask|linux.STATX_CTIME != 0 {
+ if mask&linux.STATX_CTIME != 0 {
ls.Ctime = unixToLinuxStatxTimestamp(s.Ctime)
}
- if mask|linux.STATX_MTIME != 0 {
+ if mask&linux.STATX_MTIME != 0 {
ls.Mtime = unixToLinuxStatxTimestamp(s.Mtime)
}
- if mask|linux.STATX_SIZE != 0 {
+ if mask&linux.STATX_SIZE != 0 {
ls.Size = s.Size
}
- if mask|linux.STATX_BLOCKS != 0 {
+ if mask&linux.STATX_BLOCKS != 0 {
ls.Blocks = s.Blocks
}
- // Use our own internal inode number and file owner.
- if mask|linux.STATX_INO != 0 {
+ // Use our own internal inode number.
+ if mask&linux.STATX_INO != 0 {
ls.Ino = i.ino
}
- if mask|linux.STATX_UID != 0 {
- ls.UID = uint32(i.uid)
- }
- if mask|linux.STATX_GID != 0 {
- ls.GID = uint32(i.gid)
- }
+ // Update cached mode.
+ if (mask&linux.STATX_TYPE != 0) && (mask&linux.STATX_MODE != 0) {
+ i.modeMu.Lock()
+ i.mode = linux.FileMode(s.Mode)
+ i.modeMu.Unlock()
+ }
return ls, nil
}
@@ -274,6 +317,8 @@ func (i *inode) fstat(opts vfs.StatOptions) (linux.Statx, error) {
Mask: linux.STATX_BASIC_STATS,
Blksize: uint32(s.Blksize),
Nlink: uint32(s.Nlink),
+ UID: s.Uid,
+ GID: s.Gid,
Mode: uint16(s.Mode),
Size: uint64(s.Size),
Blocks: uint64(s.Blocks),
@@ -282,15 +327,13 @@ func (i *inode) fstat(opts vfs.StatOptions) (linux.Statx, error) {
Mtime: timespecToStatxTimestamp(s.Mtim),
}
- // Use our own internal inode number and file owner.
+ // Use our own internal inode number.
//
// TODO(gvisor.dev/issue/1672): Use a kernfs-specific device number as well.
// If we use the device number from the host, it may collide with another
// sentry-internal device number. We handle device/inode numbers without
// relying on the host to prevent collisions.
ls.Ino = i.ino
- ls.UID = uint32(i.uid)
- ls.GID = uint32(i.gid)
return ls, nil
}
@@ -306,7 +349,11 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if m&^(linux.STATX_MODE|linux.STATX_SIZE|linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, &s, i.Mode(), i.uid, i.gid); err != nil {
+ mode, uid, gid, err := i.getPermissions()
+ if err != nil {
+ return err
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &s, mode.Permissions(), uid, gid); err != nil {
return err
}
@@ -314,7 +361,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if err := syscall.Fchmod(i.hostFD, uint32(s.Mode)); err != nil {
return err
}
+ i.modeMu.Lock()
i.mode = linux.FileMode(s.Mode)
+ i.modeMu.Unlock()
}
if m&linux.STATX_SIZE != 0 {
if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
@@ -351,7 +400,11 @@ func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptio
}
func (i *inode) open(d *vfs.Dentry, mnt *vfs.Mount) (*vfs.FileDescription, error) {
- fileType := i.mode.FileType()
+ mode, _, _, err := i.getPermissions()
+ if err != nil {
+ return nil, err
+ }
+ fileType := mode.FileType()
if fileType == syscall.S_IFSOCK {
if i.isTTY {
return nil, errors.New("cannot use host socket as TTY")
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index e73f1f857..b3d6299d0 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -35,6 +35,7 @@ go_library(
"//pkg/refs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 31da8b511..a429fa23d 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -728,6 +729,18 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return nil
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return nil, err
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
fs.mu.RLock()
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 794e38908..2cefef020 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -63,9 +63,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-// FilesystemType implements vfs.FilesystemType.
-type FilesystemType struct{}
-
// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
// filesystem. Concrete implementations are expected to embed this in their own
// Filesystem type.
@@ -138,8 +135,8 @@ func (fs *Filesystem) processDeferredDecRefsLocked() {
// Init initializes a kernfs filesystem. This should be called from during
// vfs.FilesystemType.NewFilesystem for the concrete filesystem embedding
// kernfs.
-func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem) {
- fs.vfsfs.Init(vfsObj, fs)
+func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem, fsType vfs.FilesystemType) {
+ fs.vfsfs.Init(vfsObj, fsType, fs)
}
// VFSFilesystem returns the generic vfs filesystem object.
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index fb0d25ad7..465451f35 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -187,9 +187,13 @@ func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, err
return nil, syserror.EPERM
}
-func (fst *fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fsType) Name() string {
+ return "kernfs"
+}
+
+func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
- fs.Init(vfsObj)
+ fs.Init(vfsObj, &fst)
root := fst.rootFn(creds, fs)
return fs.VFSFilesystem(), root.VFSDentry(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 5c19d5522..104fc9030 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -36,8 +36,13 @@ type FilesystemType struct{}
var _ vfs.FilesystemType = (*FilesystemType)(nil)
-// GetFilesystem implements vfs.FilesystemType.
-func (ft *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
k := kernel.KernelFromContext(ctx)
if k == nil {
return nil, nil, fmt.Errorf("procfs requires a kernel")
@@ -48,7 +53,7 @@ func (ft *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virtual
}
procfs := &kernfs.Filesystem{}
- procfs.VFSFilesystem().Init(vfsObj, procfs)
+ procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
var cgroups map[string]string
if opts.InternalData != nil {
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 7abfd62f2..5c617270e 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -39,10 +39,15 @@ type filesystem struct {
kernfs.Filesystem
}
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
-func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
- fs.Filesystem.Init(vfsObj)
+ fs.Filesystem.Init(vfsObj, &fsType)
k := kernel.KernelFromContext(ctx)
maxCPUCores := k.ApplicationCores()
defaultSysDirMode := linux.FileMode(0755)
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 57abd5583..6ea35affb 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -46,6 +46,7 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/sentry/vfs/lock",
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 12cc64385..e678ecc37 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -656,6 +657,13 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return nil
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+//
+// TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
fs.mu.RLock()
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 2f9e6c876..b07b0dbae 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -63,6 +63,11 @@ type filesystem struct {
nextInoMinusOne uint64 // accessed using atomic memory operations
}
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
@@ -74,7 +79,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
memFile: memFileProvider.MemoryFile(),
clock: clock,
}
- fs.vfsfs.Init(vfsObj, &fs)
+ fs.vfsfs.Init(vfsObj, &fstype, &fs)
root := fs.newDentry(fs.newDirectory(creds, 01777))
return &fs.vfsfs, &root.vfsd, nil
}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 8452ddf5b..d6546735e 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -863,3 +863,15 @@ func (t *Task) SetOOMScoreAdj(adj int32) error {
atomic.StoreInt32(&t.tg.oomScoreAdj, adj)
return nil
}
+
+// UID returns t's uid.
+// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
+func (t *Task) UID() uint32 {
+ return uint32(t.Credentials().EffectiveKUID)
+}
+
+// GID returns t's gid.
+// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
+func (t *Task) GID() uint32 {
+ return uint32(t.Credentials().EffectiveKGID)
+}
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index e801abeb8..721094bbf 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"extensions.go",
"netfilter.go",
+ "owner_matcher.go",
"targets.go",
"tcp_matcher.go",
"udp_matcher.go",
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 55bcc3ace..878f81fd5 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -517,11 +517,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
}
// TODO(gvisor.dev/issue/170): Support other chains.
- // Since we only support modifying the INPUT chain and redirect for
- // PREROUTING chain right now, make sure all other chains point to
- // ACCEPT rules.
+ // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
+ // make sure all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook != stack.Input && hook != stack.Prerouting {
+ if hook == stack.Forward || hook == stack.Postrouting {
if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
new file mode 100644
index 000000000..5949a7c29
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -0,0 +1,128 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameOwner = "owner"
+
+func init() {
+ registerMatchMaker(ownerMarshaler{})
+}
+
+// ownerMarshaler implements matchMaker for owner matching.
+type ownerMarshaler struct{}
+
+// name implements matchMaker.name.
+func (ownerMarshaler) name() string {
+ return matcherNameOwner
+}
+
+// marshal implements matchMaker.marshal.
+func (ownerMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*OwnerMatcher)
+ iptOwnerInfo := linux.IPTOwnerInfo{
+ UID: matcher.uid,
+ GID: matcher.gid,
+ }
+
+ // Support for UID match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if matcher.matchUID {
+ iptOwnerInfo.Match = linux.XT_OWNER_UID
+ } else if matcher.matchGID {
+ panic("GID match is not supported.")
+ } else {
+ panic("UID match is not set.")
+ }
+
+ buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfIPTOwnerInfo {
+ return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.IPTOwnerInfo
+ binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+
+ if matchData.Invert != 0 {
+ return nil, fmt.Errorf("invert flag is not supported for owner match")
+ }
+
+ // Support for UID match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if matchData.Match&linux.XT_OWNER_UID != linux.XT_OWNER_UID {
+ return nil, fmt.Errorf("owner match is only supported for uid")
+ }
+
+ // Check Flags.
+ var owner OwnerMatcher
+ owner.uid = matchData.UID
+ owner.gid = matchData.GID
+ owner.matchUID = true
+
+ return &owner, nil
+}
+
+type OwnerMatcher struct {
+ uid uint32
+ gid uint32
+ matchUID bool
+ matchGID bool
+ invert uint8
+}
+
+// Name implements Matcher.Name.
+func (*OwnerMatcher) Name() string {
+ return matcherNameOwner
+}
+
+// Match implements Matcher.Match.
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+ // Support only for OUTPUT chain.
+ // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
+ if hook != stack.Output {
+ return false, true
+ }
+
+ // If the packet owner is not set, drop the packet.
+ // Support for uid match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if pkt.Owner == nil || !om.matchUID {
+ return false, true
+ }
+
+ // TODO(gvisor.dev/issue/170): Need to add tests to verify
+ // drop rule when packet UID does not match owner matcher UID.
+ if pkt.Owner.UID() != om.uid {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 5f181f017..eb090e79b 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -126,6 +126,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
} else {
ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
}
if e != nil {
return nil, syserr.TranslateNetstackError(e)
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index a2a06fc8f..bf4d27c7d 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -53,6 +53,7 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index a62e43589..d1f6dfb45 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -50,6 +51,19 @@ const (
anonFileGID = auth.RootKGID
)
+// anonFilesystemType implements FilesystemType.
+type anonFilesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (anonFilesystemType) GetFilesystem(context.Context, *VirtualFilesystem, *auth.Credentials, string, GetFilesystemOptions) (*Filesystem, *Dentry, error) {
+ panic("cannot instaniate an anon filesystem")
+}
+
+// Name implemenents FilesystemType.Name.
+func (anonFilesystemType) Name() string {
+ return "none"
+}
+
// anonFilesystem is the implementation of FilesystemImpl that backs
// VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
//
@@ -222,6 +236,14 @@ func (fs *anonFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error
return syserror.EPERM
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath) (transport.BoundEndpoint, error) {
+ if !rp.Final() {
+ return nil, syserror.ENOTDIR
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements FilesystemImpl.ListxattrAt.
func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) {
if !rp.Done() {
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 332decce6..cd34782ff 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
)
// A Filesystem is a tree of nodes represented by Dentries, which forms part of
@@ -41,21 +42,30 @@ type Filesystem struct {
// immutable.
vfs *VirtualFilesystem
+ // fsType is the FilesystemType of this Filesystem.
+ fsType FilesystemType
+
// impl is the FilesystemImpl associated with this Filesystem. impl is
// immutable. This should be the last field in Dentry.
impl FilesystemImpl
}
// Init must be called before first use of fs.
-func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, impl FilesystemImpl) {
+func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) {
fs.refs = 1
fs.vfs = vfsObj
+ fs.fsType = fsType
fs.impl = impl
vfsObj.filesystemsMu.Lock()
vfsObj.filesystems[fs] = struct{}{}
vfsObj.filesystemsMu.Unlock()
}
+// FilesystemType returns the FilesystemType for this Filesystem.
+func (fs *Filesystem) FilesystemType() FilesystemType {
+ return fs.fsType
+}
+
// VirtualFilesystem returns the containing VirtualFilesystem.
func (fs *Filesystem) VirtualFilesystem() *VirtualFilesystem {
return fs.vfs
@@ -460,6 +470,11 @@ type FilesystemImpl interface {
// RemovexattrAt returns ENOTSUP.
RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
+ // BoundEndpointAt returns the Unix socket endpoint bound at the path rp.
+ //
+ // - If a non-socket file exists at rp, then BoundEndpointAt returns ECONNREFUSED.
+ BoundEndpointAt(ctx context.Context, rp *ResolvingPath) (transport.BoundEndpoint, error)
+
// PrependPath prepends a path from vd to vd.Mount().Root() to b.
//
// If vfsroot.Ok(), it is the contextual VFS root; if it is encountered
@@ -482,7 +497,7 @@ type FilesystemImpl interface {
// Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
- // TODO: inotify_add_watch(); bind()
+ // TODO: inotify_add_watch()
}
// PrependPathAtVFSRootError is returned by implementations of
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index bb9cada81..f2298f7f6 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -30,6 +30,9 @@ type FilesystemType interface {
// along with its mount root. A reference is taken on the returned
// Filesystem and Dentry.
GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error)
+
+ // Name returns the name of this FilesystemType.
+ Name() string
}
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem.
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 05f6233f9..4b68cabda 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -24,6 +24,9 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// lastMountID is used to allocate mount ids. Must be accessed atomically.
+var lastMountID uint64
+
// A Mount is a replacement of a Dentry (Mount.key.point) from one Filesystem
// (Mount.key.parent.fs) with a Dentry (Mount.root) from another Filesystem
// (Mount.fs), which applies to path resolution in the context of a particular
@@ -48,6 +51,9 @@ type Mount struct {
fs *Filesystem
root *Dentry
+ // ID is the immutable mount ID.
+ ID uint64
+
// key is protected by VirtualFilesystem.mountMu and
// VirtualFilesystem.mounts.seq, and may be nil. References are held on
// key.parent and key.point if they are not nil.
@@ -87,6 +93,7 @@ type Mount struct {
func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount {
mnt := &Mount{
+ ID: atomic.AddUint64(&lastMountID, 1),
vfs: vfs,
fs: fs,
root: root,
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
index b318c681a..f21a88034 100644
--- a/pkg/sentry/vfs/pathname.go
+++ b/pkg/sentry/vfs/pathname.go
@@ -90,6 +90,49 @@ loop:
return b.String(), nil
}
+// PathnameReachable returns an absolute pathname to vd, consistent with
+// Linux's __d_path() (as used by seq_path_root()). If vfsroot.Ok() and vd is
+// not reachable from vfsroot, such that seq_path_root() would return SEQ_SKIP
+// (causing the entire containing entry to be skipped), PathnameReachable
+// returns ("", nil).
+func (vfs *VirtualFilesystem) PathnameReachable(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef()
+ }
+ }()
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ return "", nil
+ }
+ if haveRef {
+ vd.DecRef()
+ }
+ vd = nextVD
+ haveRef = true
+ case PrependPathAtVFSRootError:
+ break loop
+ case PrependPathAtNonMountRootError, PrependPathSyntheticError:
+ return "", nil
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ return b.String(), nil
+}
+
// PathnameForGetcwd returns an absolute pathname to vd, consistent with
// Linux's sys_getcwd().
func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 03d1fb943..720b90d8f 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -38,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -133,7 +134,7 @@ func (vfs *VirtualFilesystem) Init() error {
anonfs := anonFilesystem{
devMinor: anonfsDevMinor,
}
- anonfs.vfsfs.Init(vfs, &anonfs)
+ anonfs.vfsfs.Init(vfs, &anonFilesystemType{}, &anonfs)
defer anonfs.vfsfs.DecRef()
anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{})
if err != nil {
@@ -230,7 +231,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -271,7 +272,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -307,7 +308,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -340,7 +341,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -350,6 +351,33 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
}
}
+// BoundEndpointAt gets the bound endpoint at the given path, if one exists.
+func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (transport.BoundEndpoint, error) {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return nil, syserror.ECONNREFUSED
+ }
+ return nil, syserror.ENOENT
+ }
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return bep, nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
// OpenAt returns a FileDescription providing access to the file at the given
// path. A reference is taken on the returned FileDescription.
func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
@@ -494,7 +522,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -527,7 +555,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -608,7 +636,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -640,7 +668,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 4b5a0fca6..f86db0999 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -27,6 +27,7 @@ import (
var (
E2BIG = error(syscall.E2BIG)
EACCES = error(syscall.EACCES)
+ EADDRINUSE = error(syscall.EADDRINUSE)
EAGAIN = error(syscall.EAGAIN)
EBADF = error(syscall.EBADF)
EBADFD = error(syscall.EBADFD)
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 17e94c562..8d42cd066 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -15,6 +15,10 @@
// Package buffer provides the implementation of a buffer view.
package buffer
+import (
+ "bytes"
+)
+
// View is a slice of a buffer, with convenience methods.
type View []byte
@@ -45,6 +49,13 @@ func (v *View) CapLength(length int) {
*v = (*v)[:length:length]
}
+// Reader returns a bytes.Reader for v.
+func (v *View) Reader() bytes.Reader {
+ var r bytes.Reader
+ r.Reset(*v)
+ return r
+}
+
// ToVectorisedView returns a VectorisedView containing the receiver.
func (v View) ToVectorisedView() VectorisedView {
return NewVectorisedView(len(v), []View{v})
@@ -162,3 +173,12 @@ func (vv *VectorisedView) AppendView(v View) {
vv.views = append(vv.views, v)
vv.size += len(v)
}
+
+// Readers returns a bytes.Reader for each of vv's views.
+func (vv *VectorisedView) Readers() []bytes.Reader {
+ readers := make([]bytes.Reader, 0, len(vv.views))
+ for _, v := range vv.views {
+ readers = append(readers, v.Reader())
+ }
+ return readers
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 9da0d71f8..7094f3f0b 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -14,6 +14,7 @@ go_library(
"interfaces.go",
"ipv4.go",
"ipv6.go",
+ "ipv6_extension_headers.go",
"ipv6_fragment.go",
"ndp_neighbor_advert.go",
"ndp_neighbor_solicit.go",
@@ -55,11 +56,13 @@ go_test(
size = "small",
srcs = [
"eth_test.go",
+ "ipv6_extension_headers_test.go",
"ndp_test.go",
],
library = ":header",
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"@com_github_google_go-cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
new file mode 100644
index 000000000..1b6c3f328
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -0,0 +1,531 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier.
+type IPv6ExtensionHeaderIdentifier uint8
+
+const (
+ // IPv6HopByHopOptionsExtHdrIdentifier is the header identifier of a Hop by
+ // Hop Options extension header, as per RFC 8200 section 4.3.
+ IPv6HopByHopOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 0
+
+ // IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension
+ // header, as per RFC 8200 section 4.4.
+ IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43
+
+ // IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment
+ // extension header, as per RFC 8200 section 4.5.
+ IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44
+
+ // IPv6DestinationOptionsExtHdrIdentifier is the header identifier of a
+ // Destination Options extension header, as per RFC 8200 section 4.6.
+ IPv6DestinationOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 60
+
+ // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
+ // of an IPv6 payload, as per RFC 8200 section 4.7.
+ IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
+)
+
+const (
+ // ipv6UnknownExtHdrOptionActionMask is the mask of the action to take when
+ // a node encounters an unrecognized option.
+ ipv6UnknownExtHdrOptionActionMask = 192
+
+ // ipv6UnknownExtHdrOptionActionShift is the least significant bits to discard
+ // from the action value for an unrecognized option identifier.
+ ipv6UnknownExtHdrOptionActionShift = 6
+
+ // ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field
+ // within an IPv6RoutingExtHdr.
+ ipv6RoutingExtHdrSegmentsLeftIdx = 1
+
+ // ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the
+ // Fragment Offset field within an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrFragmentOffsetOffset = 0
+
+ // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to
+ // discard from the Fragment Offset.
+ ipv6FragmentExtHdrFragmentOffsetShift = 3
+
+ // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
+ // IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrFlagsIdx = 1
+
+ // ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the
+ // flags field of an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrMFlagMask = 1
+
+ // ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification
+ // field within an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrIdentificationOffset = 2
+
+ // ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length
+ // field. That is, given a Length field of 2, the extension header expects
+ // 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for
+ // details about the first 8 bytes' exclusion from the Length field).
+ ipv6ExtHdrLenBytesPerUnit = 8
+
+ // ipv6ExtHdrLenBytesExcluded is the number of bytes excluded from an
+ // extension header's Length field following the Length field.
+ //
+ // The Length field excludes the first 8 bytes, but the Next Header and Length
+ // field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes
+ // after the Length field.
+ //
+ // This ensures that every extension header is at least 8 bytes.
+ ipv6ExtHdrLenBytesExcluded = 6
+
+ // IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment
+ // extension header's Fragment Offset field. That is, given a Fragment Offset
+ // of 2, the extension header is indiciating that the fragment's payload
+ // starts at the 16th byte in the reassembled packet.
+ IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
+)
+
+// IPv6PayloadHeader is implemented by the various headers that can be found
+// in an IPv6 payload.
+//
+// These headers include IPv6 extension headers or upper layer data.
+type IPv6PayloadHeader interface {
+ isIPv6PayloadHeader()
+}
+
+// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator
+// encounters a Next Header field it does not recognize as an IPv6 extension
+// header.
+type IPv6RawPayloadHeader struct {
+ Identifier IPv6ExtensionHeaderIdentifier
+ Buf buffer.VectorisedView
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {}
+
+// ipv6OptionsExtHdr is an IPv6 extension header that holds options.
+type ipv6OptionsExtHdr []byte
+
+// Iter returns an iterator over the IPv6 extension header options held in b.
+func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
+ it := IPv6OptionsExtHdrOptionsIterator{}
+ it.reader.Reset(b)
+ return it
+}
+
+// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header
+// options.
+//
+// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last
+// used, no changes to the underlying buffer may happen. Doing so may cause
+// undefined and unexpected behaviour. It is fine to obtain an
+// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then
+// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator
+// obtained before modification is no longer used.
+type IPv6OptionsExtHdrOptionsIterator struct {
+ reader bytes.Reader
+}
+
+// IPv6OptionUnknownAction is the action that must be taken if the processing
+// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
+type IPv6OptionUnknownAction int
+
+const (
+ // IPv6OptionUnknownActionSkip indicates that the unrecognized option must
+ // be skipped and the node should continue processing the header.
+ IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = 0
+
+ // IPv6OptionUnknownActionDiscard indicates that the packet must be silently
+ // discarded.
+ IPv6OptionUnknownActionDiscard IPv6OptionUnknownAction = 1
+
+ // IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be
+ // discarded and the node must send an ICMP Parameter Problem, Code 2, message
+ // to the packet's source, regardless of whether or not the packet's
+ // Destination was a multicast address.
+ IPv6OptionUnknownActionDiscardSendICMP IPv6OptionUnknownAction = 2
+
+ // IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the
+ // packet must be discarded and the node must send an ICMP Parameter Problem,
+ // Code 2, message to the packet's source only if the packet's Destination was
+ // not a multicast address.
+ IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest IPv6OptionUnknownAction = 3
+)
+
+// IPv6ExtHdrOption is implemented by the various IPv6 extension header options.
+type IPv6ExtHdrOption interface {
+ // UnknownAction returns the action to take in response to an unrecognized
+ // option.
+ UnknownAction() IPv6OptionUnknownAction
+
+ // isIPv6ExtHdrOption is used to "lock" this interface so it is not
+ // implemented by other packages.
+ isIPv6ExtHdrOption()
+}
+
+// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIndentifier uint8
+
+const (
+ // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
+ // provides 1 byte padding, as outlined in RFC 8200 section 4.2.
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0
+
+ // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that
+ // provides variable length byte padding, as outlined in RFC 8200 section 4.2.
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1
+)
+
+// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
+// header option that is unknown by the parsing utilities.
+type IPv6UnknownExtHdrOption struct {
+ Identifier IPv6ExtHdrOptionIndentifier
+ Data []byte
+}
+
+// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
+func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
+ return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+}
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
+func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
+
+// Next returns the next option in the options data.
+//
+// If the next item is not a known extension header option,
+// IPv6UnknownExtHdrOption will be returned with the option identifier and data.
+//
+// The return is of the format (option, done, error). done will be true when
+// Next is unable to return anything because the iterator has reached the end of
+// the options data, or an error occured.
+func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
+ for {
+ temp, err := i.reader.ReadByte()
+ if err != nil {
+ // If we can't read the first byte of a new option, then we know the
+ // options buffer has been exhausted and we are done iterating.
+ return nil, true, nil
+ }
+ id := IPv6ExtHdrOptionIndentifier(temp)
+
+ // If the option identifier indicates the option is a Pad1 option, then we
+ // know the option does not have Length and Data fields. End processing of
+ // the Pad1 option and continue processing the buffer as a new option.
+ if id == ipv6Pad1ExtHdrOptionIdentifier {
+ continue
+ }
+
+ length, err := i.reader.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
+ // we start parsing an option; we expect the reader to contain enough
+ // bytes for the whole option.
+ return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
+ }
+
+ // Special-case the variable length padding option to avoid a copy.
+ if id == ipv6PadNExtHdrOptionIdentifier {
+ // Do we have enough bytes in the reader for the PadN option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
+ panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
+ }
+
+ // End processing of the PadN option and continue processing the buffer as
+ // a new option.
+ continue
+ }
+
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ }
+
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
+ }
+}
+
+// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options
+// extension header.
+type IPv6HopByHopOptionsExtHdr struct {
+ ipv6OptionsExtHdr
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {}
+
+// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options
+// extension header.
+type IPv6DestinationOptionsExtHdr struct {
+ ipv6OptionsExtHdr
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
+
+// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
+// data as outlined in RFC 8200 section 4.4.
+type IPv6RoutingExtHdr []byte
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {}
+
+// SegmentsLeft returns the Segments Left field.
+func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 {
+ return b[ipv6RoutingExtHdrSegmentsLeftIdx]
+}
+
+// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
+// data as outlined in RFC 8200 section 4.5.
+//
+// Note, the buffer does not include the Next Header and Reserved fields.
+type IPv6FragmentExtHdr [6]byte
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {}
+
+// FragmentOffset returns the Fragment Offset field.
+//
+// This value indicates where the buffer following the Fragment extension header
+// starts in the target (reassembled) packet.
+func (b IPv6FragmentExtHdr) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift
+}
+
+// More returns the More (M) flag.
+//
+// This indicates whether any fragments are expected to succeed b.
+func (b IPv6FragmentExtHdr) More() bool {
+ return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0
+}
+
+// ID returns the Identification field.
+//
+// This value is used to uniquely identify the packet, between a
+// souce and destination.
+func (b IPv6FragmentExtHdr) ID() uint32 {
+ return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
+}
+
+// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
+//
+// The IPv6 payload may contain IPv6 extension headers before any upper layer
+// data.
+//
+// Note, between when an IPv6PayloadIterator is obtained and last used, no
+// changes to the payload may happen. Doing so may cause undefined and
+// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate
+// over the first few headers then modify the backing payload so long as the
+// IPv6PayloadIterator obtained before modification is no longer used.
+type IPv6PayloadIterator struct {
+ // The identifier of the next header to parse.
+ nextHdrIdentifier IPv6ExtensionHeaderIdentifier
+
+ // reader is an io.Reader over payload.
+ reader bufio.Reader
+ payload buffer.VectorisedView
+
+ // Indicates to the iterator that it should return the remaining payload as a
+ // raw payload on the next call to Next.
+ forceRaw bool
+}
+
+// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
+// extension headers, or a raw payload if the payload cannot be parsed.
+func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.VectorisedView) IPv6PayloadIterator {
+ readers := payload.Readers()
+ readerPs := make([]io.Reader, 0, len(readers))
+ for i := range readers {
+ readerPs = append(readerPs, &readers[i])
+ }
+
+ return IPv6PayloadIterator{
+ nextHdrIdentifier: nextHdrIdentifier,
+ payload: payload.Clone(nil),
+ // We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ }
+}
+
+// AsRawHeader returns the remaining payload of i as a raw header and
+// completes the iterator.
+//
+// Calls to Next after calling AsRawHeader on i will indicate that the
+// iterator is done.
+func (i *IPv6PayloadIterator) AsRawHeader() IPv6RawPayloadHeader {
+ buf := i.payload
+ identifier := i.nextHdrIdentifier
+
+ // Mark i as done.
+ *i = IPv6PayloadIterator{
+ nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ }
+
+ return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf}
+}
+
+// Next returns the next item in the payload.
+//
+// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader
+// will be returned with the remaining bytes and next header identifier.
+//
+// The return is of the format (header, done, error). done will be true when
+// Next is unable to return anything because the iterator has reached the end of
+// the payload, or an error occured.
+func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ // We could be forced to return i as a raw header when the previous header was
+ // a fragment extension header as the data following the fragment extension
+ // header may not be complete.
+ if i.forceRaw {
+ return i.AsRawHeader(), false, nil
+ }
+
+ // Is the header we are parsing a known extension header?
+ switch i.nextHdrIdentifier {
+ case IPv6HopByHopOptionsExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil
+ case IPv6RoutingExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6RoutingExtHdr(bytes), false, nil
+ case IPv6FragmentExtHdrIdentifier:
+ var data [6]byte
+ // We ignore the returned bytes becauase we know the fragment extension
+ // header specific data will fit in data.
+ nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
+ if err != nil {
+ return nil, true, err
+ }
+
+ fragmentExtHdr := IPv6FragmentExtHdr(data)
+
+ // If the packet is a fragmented packet, do not attempt to parse
+ // anything after the fragment extension header as the data following
+ // the extension header may not be complete.
+ if fragmentExtHdr.More() || fragmentExtHdr.FragmentOffset() != 0 {
+ i.forceRaw = true
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return fragmentExtHdr, false, nil
+ case IPv6DestinationOptionsExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil
+ case IPv6NoNextHeaderIdentifier:
+ // This indicates the end of the IPv6 payload.
+ return nil, true, nil
+
+ default:
+ // The header we are parsing is not a known extension header. Return the
+ // raw payload.
+ return i.AsRawHeader(), false, nil
+ }
+}
+
+// nextHeaderData returns the extension header's Next Header field and raw data.
+//
+// fragmentHdr indicates that the extension header being parsed is the Fragment
+// extension header so the Length field should be ignored as it is Reserved
+// for the Fragment extension header.
+//
+// If bytes is not nil, extension header specific data will be read into bytes
+// if it has enough capacity. If bytes is provided but does not have enough
+// capacity for the data, nextHeaderData will panic.
+func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, []byte, error) {
+ // We ignore the number of bytes read because we know we will only ever read
+ // at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read
+ // would return io.EOF to indicate that io.Reader has reached the end of the
+ // payload.
+ nextHdrIdentifier, err := i.reader.ReadByte()
+ i.payload.TrimFront(1)
+ if err != nil {
+ return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+
+ var length uint8
+ length, err = i.reader.ReadByte()
+ i.payload.TrimFront(1)
+ if err != nil {
+ if fragmentHdr {
+ return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+
+ return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+ if fragmentHdr {
+ length = 0
+ }
+
+ bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
+ if bytes == nil {
+ bytes = make([]byte, bytesLen)
+ } else if n := len(bytes); n < bytesLen {
+ panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
+ }
+
+ n, err := io.ReadFull(&i.reader, bytes)
+ i.payload.TrimFront(n)
+ if err != nil {
+ return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
+ }
+
+ return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
new file mode 100644
index 000000000..133ccc8b6
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -0,0 +1,957 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold the same Identifier value and
+// contain the same bytes in Buf, even if the bytes are split across views
+// differently.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool {
+ return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView())
+}
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6HopByHopOptionsExtHdr) Equal(b IPv6HopByHopOptionsExtHdr) bool {
+ return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
+}
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool {
+ return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
+}
+
+func TestIPv6UnknownExtHdrOption(t *testing.T) {
+ tests := []struct {
+ name string
+ identifier IPv6ExtHdrOptionIndentifier
+ expectedUnknownAction IPv6OptionUnknownAction
+ }{
+ {
+ name: "Skip with zero LSBs",
+ identifier: 0,
+ expectedUnknownAction: IPv6OptionUnknownActionSkip,
+ },
+ {
+ name: "Discard with zero LSBs",
+ identifier: 64,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscard,
+ },
+ {
+ name: "Discard and ICMP with zero LSBs",
+ identifier: 128,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
+ },
+ {
+ name: "Discard and ICMP for non multicast destination with zero LSBs",
+ identifier: 192,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ },
+ {
+ name: "Skip with non-zero LSBs",
+ identifier: 63,
+ expectedUnknownAction: IPv6OptionUnknownActionSkip,
+ },
+ {
+ name: "Discard with non-zero LSBs",
+ identifier: 127,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscard,
+ },
+ {
+ name: "Discard and ICMP with non-zero LSBs",
+ identifier: 191,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
+ },
+ {
+ name: "Discard and ICMP for non multicast destination with non-zero LSBs",
+ identifier: 255,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opt := &IPv6UnknownExtHdrOption{Identifier: test.identifier, Data: []byte{1, 2, 3, 4}}
+ if a := opt.UnknownAction(); a != test.expectedUnknownAction {
+ t.Fatalf("got UnknownAction() = %d, want = %d", a, test.expectedUnknownAction)
+ }
+ })
+ }
+
+}
+
+func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ err error
+ }{
+ {
+ name: "Single unknown with zero length",
+ bytes: []byte{255, 0},
+ },
+ {
+ name: "Single unknown with non-zero length",
+ bytes: []byte{255, 3, 1, 2, 3},
+ },
+ {
+ name: "Two options",
+ bytes: []byte{
+ 255, 0,
+ 254, 1, 1,
+ },
+ },
+ {
+ name: "Three options",
+ bytes: []byte{
+ 255, 0,
+ 254, 1, 1,
+ 253, 4, 2, 3, 4, 5,
+ },
+ },
+ {
+ name: "Single unknown only identifier",
+ bytes: []byte{255},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Single unknown too small with length = 1",
+ bytes: []byte{255, 1},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Single unknown too small with length = 2",
+ bytes: []byte{255, 2, 1},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown only identifier",
+ bytes: []byte{
+ 255, 0,
+ 254,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown missing data",
+ bytes: []byte{
+ 255, 0,
+ 254, 1,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown too small",
+ bytes: []byte{
+ 255, 0,
+ 254, 2, 1,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "One Pad1",
+ bytes: []byte{0},
+ },
+ {
+ name: "Multiple Pad1",
+ bytes: []byte{0, 0, 0},
+ },
+ {
+ name: "Multiple PadN",
+ bytes: []byte{
+ // Pad3
+ 1, 1, 1,
+
+ // Pad5
+ 1, 3, 1, 2, 3,
+ },
+ },
+ {
+ name: "Pad5 too small middle of data buffer",
+ bytes: []byte{1, 3, 1, 2},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Pad5 no data",
+ bytes: []byte{1, 3},
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
+ for i := 0; ; i++ {
+ _, done, err := it.Next()
+ if err != nil {
+ // If we encountered a non-nil error while iterating, make sure it is
+ // is the same error as expectedErr.
+ if !errors.Is(err, expectedErr) {
+ t.Fatalf("got %d-th Next() = %v, want = %v", i, err, expectedErr)
+ }
+
+ return
+ }
+ if done {
+ // If we are done (without an error), make sure that we did not expect
+ // an error.
+ if expectedErr != nil {
+ t.Fatalf("expected error when iterating; want = %s", expectedErr)
+ }
+
+ return
+ }
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Run("Hop By Hop", func(t *testing.T) {
+ extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ check(t, extHdr.Iter(), test.err)
+ })
+
+ t.Run("Destination", func(t *testing.T) {
+ extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ check(t, extHdr.Iter(), test.err)
+ })
+ })
+ }
+}
+
+func TestIPv6OptionsExtHdrIter(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ expected []IPv6ExtHdrOption
+ }{
+ {
+ name: "Single unknown with zero length",
+ bytes: []byte{255, 0},
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
+ },
+ },
+ {
+ name: "Single unknown with non-zero length",
+ bytes: []byte{255, 3, 1, 2, 3},
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{1, 2, 3}},
+ },
+ },
+ {
+ name: "Single Pad1",
+ bytes: []byte{0},
+ },
+ {
+ name: "Two Pad1",
+ bytes: []byte{0, 0},
+ },
+ {
+ name: "Single Pad3",
+ bytes: []byte{1, 1, 1},
+ },
+ {
+ name: "Single Pad5",
+ bytes: []byte{1, 3, 1, 2, 3},
+ },
+ {
+ name: "Multiple Pad",
+ bytes: []byte{
+ // Pad1
+ 0,
+
+ // Pad2
+ 1, 0,
+
+ // Pad3
+ 1, 1, 1,
+
+ // Pad4
+ 1, 2, 1, 2,
+
+ // Pad5
+ 1, 3, 1, 2, 3,
+ },
+ },
+ {
+ name: "Multiple options",
+ bytes: []byte{
+ // Pad1
+ 0,
+
+ // Unknown
+ 255, 0,
+
+ // Pad2
+ 1, 0,
+
+ // Unknown
+ 254, 1, 1,
+
+ // Pad3
+ 1, 1, 1,
+
+ // Unknown
+ 253, 4, 2, 3, 4, 5,
+
+ // Pad4
+ 1, 2, 1, 2,
+ },
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
+ &IPv6UnknownExtHdrOption{Identifier: 254, Data: []byte{1}},
+ &IPv6UnknownExtHdrOption{Identifier: 253, Data: []byte{2, 3, 4, 5}},
+ },
+ },
+ }
+
+ checkIter := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expected []IPv6ExtHdrOption) {
+ for i, e := range expected {
+ opt, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(i=%d) Next(): %s", i, err)
+ }
+ if done {
+ t.Errorf("(i=%d) unexpectedly done iterating", i)
+ }
+ if diff := cmp.Diff(e, opt); diff != "" {
+ t.Errorf("(i=%d) got option mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ opt, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(last) Next(): %s", err)
+ }
+ if !done {
+ t.Errorf("(last) iterator unexpectedly not done")
+ }
+ if opt != nil {
+ t.Errorf("(last) got Next() = %T, want = nil", opt)
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Run("Hop By Hop", func(t *testing.T) {
+ extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ checkIter(t, extHdr.Iter(), test.expected)
+ })
+
+ t.Run("Destination", func(t *testing.T) {
+ extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ checkIter(t, extHdr.Iter(), test.expected)
+ })
+ })
+ }
+}
+
+func TestIPv6RoutingExtHdr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ segmentsLeft uint8
+ }{
+ {
+ name: "Zeroes",
+ bytes: []byte{0, 0, 0, 0, 0, 0},
+ segmentsLeft: 0,
+ },
+ {
+ name: "Ones",
+ bytes: []byte{1, 1, 1, 1, 1, 1},
+ segmentsLeft: 1,
+ },
+ {
+ name: "Mixed",
+ bytes: []byte{1, 2, 3, 4, 5, 6},
+ segmentsLeft: 2,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ extHdr := IPv6RoutingExtHdr(test.bytes)
+ if got := extHdr.SegmentsLeft(); got != test.segmentsLeft {
+ t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft)
+ }
+ })
+ }
+}
+
+func TestIPv6FragmentExtHdr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes [6]byte
+ fragmentOffset uint16
+ more bool
+ id uint32
+ }{
+ {
+ name: "Zeroes",
+ bytes: [6]byte{0, 0, 0, 0, 0, 0},
+ fragmentOffset: 0,
+ more: false,
+ id: 0,
+ },
+ {
+ name: "Ones",
+ bytes: [6]byte{0, 9, 0, 0, 0, 1},
+ fragmentOffset: 1,
+ more: true,
+ id: 1,
+ },
+ {
+ name: "Mixed",
+ bytes: [6]byte{68, 9, 128, 4, 2, 1},
+ fragmentOffset: 2177,
+ more: true,
+ id: 2147746305,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ extHdr := IPv6FragmentExtHdr(test.bytes)
+ if got := extHdr.FragmentOffset(); got != test.fragmentOffset {
+ t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset)
+ }
+ if got := extHdr.More(); got != test.more {
+ t.Errorf("got More() = %t, want = %t", got, test.more)
+ }
+ if got := extHdr.ID(); got != test.id {
+ t.Errorf("got ID() = %d, want = %d", got, test.id)
+ }
+ })
+ }
+}
+
+func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView {
+ size := 0
+ var vs []buffer.View
+
+ for _, b := range bs {
+ vs = append(vs, buffer.View(b))
+ size += len(b)
+ }
+
+ return buffer.NewVectorisedView(size, vs)
+}
+
+func TestIPv6ExtHdrIterErr(t *testing.T) {
+ tests := []struct {
+ name string
+ firstNextHdr IPv6ExtensionHeaderIdentifier
+ payload buffer.VectorisedView
+ err error
+ }{
+ {
+ name: "Upper layer only without data",
+ firstNextHdr: 255,
+ },
+ {
+ name: "Upper layer only with data",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
+ },
+ {
+ name: "No next header",
+ firstNextHdr: IPv6NoNextHeaderIdentifier,
+ },
+ {
+ name: "No next header with data",
+ firstNextHdr: IPv6NoNextHeaderIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
+ },
+ {
+ name: "Valid single hop by hop",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
+ },
+ {
+ name: "Hop by hop too small",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single fragment",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}),
+ },
+ {
+ name: "Fragment too small",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single destination",
+ firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
+ },
+ {
+ name: "Destination too small",
+ firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single routing",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}),
+ },
+ {
+ name: "Valid single routing across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}),
+ },
+ {
+ name: "Routing too small with zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid routing with non-zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}),
+ },
+ {
+ name: "Valid routing with non-zero length field across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}),
+ },
+ {
+ name: "Routing too small with non-zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Routing too small with non-zero length field across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Mixed",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3, 4,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ },
+ {
+ name: "Mixed without upper layer data",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3, 4,
+ }),
+ },
+ {
+ name: "Mixed without upper layer data but last ext hdr too small",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3,
+ }),
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
+
+ for i := 0; ; i++ {
+ _, done, err := it.Next()
+ if err != nil {
+ // If we encountered a non-nil error while iterating, make sure it is
+ // is the same error as test.err.
+ if !errors.Is(err, test.err) {
+ t.Fatalf("got %d-th Next() = %v, want = %v", i, err, test.err)
+ }
+
+ return
+ }
+ if done {
+ // If we are done (without an error), make sure that we did not expect
+ // an error.
+ if test.err != nil {
+ t.Fatalf("expected error when iterating; want = %s", test.err)
+ }
+
+ return
+ }
+ }
+ })
+ }
+}
+
+func TestIPv6ExtHdrIter(t *testing.T) {
+ routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4})
+ upperLayerData := buffer.View([]byte{1, 2, 3, 4})
+ tests := []struct {
+ name string
+ firstNextHdr IPv6ExtensionHeaderIdentifier
+ payload buffer.VectorisedView
+ expected []IPv6PayloadHeader
+ }{
+ // With a non-atomic fragment, the payload after the fragment will not be
+ // parsed because the payload may not be complete.
+ {
+ name: "hopbyhop - fragment - routing - upper",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Fragment extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2, 3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6RoutingExtHdrIdentifier,
+ Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "fragment - routing - upper (across views)",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2}, []byte{3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6RoutingExtHdrIdentifier,
+ Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+
+ // If we have an atomic fragment, the payload following the fragment
+ // extension header should be parsed normally.
+ {
+ name: "atomic fragment - routing - destination - upper",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 1, 4, 1, 2, 3, 4,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "atomic fragment - routing - upper (across views)",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2}, []byte{3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2}, []byte{3, 4}),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ },
+ },
+ },
+ {
+ name: "atomic fragment - destination - no next header",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Res (Reserved) bits are 1 which should not affect anything.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 0, 6, 128, 4, 2, 1,
+
+ // Destination Options extension header.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ },
+ },
+ {
+ name: "routing - atomic fragment - no next header",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ },
+ },
+ {
+ name: "routing - atomic fragment - no next header (across views)",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ },
+ },
+ {
+ name: "hopbyhop - routing - fragment - no next header",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Fragment Offset = 32; Res = 6.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6NoNextHeaderIdentifier,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+
+ // Test the raw payload for common transport layer protocol numbers.
+ {
+ name: "TCP raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "UDP raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "ICMPv4 raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "ICMPv6 raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "Unknwon next header raw payload",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "Unknwon next header raw payload (across views)",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ }},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
+
+ for i, e := range test.expected {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(i=%d) Next(): %s", i, err)
+ }
+ if done {
+ t.Errorf("(i=%d) unexpectedly done iterating", i)
+ }
+ if diff := cmp.Diff(e, extHdr); diff != "" {
+ t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ extHdr, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(last) Next(): %s", err)
+ }
+ if !done {
+ t.Errorf("(last) iterator unexpectedly not done")
+ }
+ if extHdr != nil {
+ t.Errorf("(last) got Next() = %T, want = nil", extHdr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
index 6a215938b..8f65713c5 100644
--- a/pkg/tcpip/network/hash/hash.go
+++ b/pkg/tcpip/network/hash/hash.go
@@ -80,12 +80,12 @@ func IPv4FragmentHash(h header.IPv4) uint32 {
// RFC 2640 (sec 4.5) is not very sharp on this aspect.
// As a reference, also Linux ignores the protocol to compute
// the hash (inet6_hash_frag).
-func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 {
+func IPv6FragmentHash(h header.IPv6, id uint32) uint32 {
t := h.SourceAddress()
y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = h.DestinationAddress()
z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- return Hash3Words(f.ID(), y, z, hashIV)
+ return Hash3Words(id, y, z, hashIV)
}
func rol32(v, shift uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b3ee6000e..a7d9a8b25 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -244,6 +244,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt); !ok {
+ // iptables is telling us to drop the packet.
+ return nil
+ }
+
if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
@@ -280,7 +288,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac
return len(pkts), nil
}
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
for i := range pkts {
+ if ok := ipt.Check(stack.Output, pkts[i]); !ok {
+ // iptables is telling us to drop the packet.
+ continue
+ }
ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params)
pkts[i].NetworkHeader = buffer.View(ip)
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index fb11874c6..a93a7621a 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -13,6 +13,8 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -36,5 +38,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 8640feffc..e0dd5afd3 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,7 +15,7 @@
package ipv6
import (
- "log"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -199,7 +199,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
opt, done, err := it.Next()
if err != nil {
// This should never happen as Iter(true) above did not return an error.
- log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
}
if done {
break
@@ -306,7 +306,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
opt, done, err := it.Next()
if err != nil {
// This should never happen as Iter(true) above did not return an error.
- log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
}
if done {
break
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 29e597002..685239017 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -21,11 +21,14 @@
package ipv6
import (
+ "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -49,6 +52,7 @@ type endpoint struct {
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
+ fragmentation *fragmentation.Fragmentation
protocol *protocol
}
@@ -172,6 +176,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
headerView := pkt.Data.First()
h := header.IPv6(headerView)
if !h.IsValid(pkt.Data.Size()) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
@@ -179,14 +184,184 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
pkt.Data.TrimFront(header.IPv6MinimumSize)
pkt.Data.CapLength(int(h.PayloadLength()))
- p := h.TransportProtocol()
- if p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, pkt)
- return
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data)
+
+ for firstHeader := true; ; firstHeader = false {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6HopByHopOptionsExtHdr:
+ // As per RFC 8200 section 4.1, the Hop By Hop extension header is
+ // restricted to appear immediately after an IPv6 fixed header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
+ // (unrecognized next header) error in response to an extension header's
+ // Next Header field with the Hop By Hop extension header identifier.
+ if !firstHeader {
+ return
+ }
+
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Hop By Hop extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RoutingExtHdr:
+ // As per RFC 8200 section 4.4, if a node encounters a routing header with
+ // an unrecognized routing type value, with a non-zero Segments Left
+ // value, the node must discard the packet and send an ICMP Parameter
+ // Problem, Code 0. If the Segments Left is 0, the node must ignore the
+ // Routing extension header and process the next header in the packet.
+ //
+ // Note, the stack does not yet handle any type of routing extension
+ // header, so we just make sure Segments Left is zero before processing
+ // the next extension header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
+ // unrecognized routing types with a non-zero Segments Left value.
+ if extHdr.SegmentsLeft() != 0 {
+ return
+ }
+
+ case header.IPv6FragmentExtHdr:
+ fragmentOffset := extHdr.FragmentOffset()
+ more := extHdr.More()
+ if !more && fragmentOffset == 0 {
+ // This fragment extension header indicates that this packet is an
+ // atomic fragment. An atomic fragment is a fragment that contains
+ // all the data required to reassemble a full packet. As per RFC 6946,
+ // atomic fragments must not interfere with "normal" fragmented traffic
+ // so we skip processing the fragment instead of feeding it through the
+ // reassembly process below.
+ continue
+ }
+
+ rawPayload := it.AsRawHeader()
+ fragmentPayloadLen := rawPayload.Buf.Size()
+ if fragmentPayloadLen == 0 {
+ // Drop the packet as it's marked as a fragment but has no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ // The packet is a fragment, let's try to reassemble it.
+ start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
+ last := start + uint16(fragmentPayloadLen) - 1
+
+ // Drop the packet if the fragmentOffset is incorrect. i.e the
+ // combination of fragmentOffset and pkt.Data.size() causes a
+ // wrap around resulting in last being less than the offset.
+ if last < start {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ var ready bool
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ if ready {
+ // We create a new iterator with the reassembled packet because we could
+ // have more extension headers in the reassembled payload, as per RFC
+ // 8200 section 4.5.
+ it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data)
+ }
+
+ case header.IPv6DestinationOptionsExtHdr:
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Destination extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RawPayloadHeader:
+ // If the last header in the payload isn't a known IPv6 extension header,
+ // handle it as if it is transport layer data.
+ pkt.Data = extHdr.Buf
+
+ if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, headerView, pkt)
+ } else {
+ r.Stats().IP.PacketsDelivered.Increment()
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ }
+
+ default:
+ // If we receive a packet for an extension header we do not yet handle,
+ // drop the packet for now.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ r.Stats().UnknownProtocolRcvdPackets.Increment()
+ return
+ }
}
-
- r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, pkt)
}
// Close cleans up resources associated with the endpoint.
@@ -229,6 +404,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
protocol: p,
}, nil
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index ed98ef22a..37f7e53ce 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -17,6 +17,7 @@ package ipv6
import (
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -33,6 +34,14 @@ const (
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+
+ // Tests use the extension header identifier values as uint8 instead of
+ // header.IPv6ExtensionHeaderIdentifier.
+ hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier)
+ routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier)
+ fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
+ destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
+ noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
)
// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
@@ -268,3 +277,975 @@ func TestAddIpv6Address(t *testing.T) {
})
}
}
+
+func TestReceiveIPv6ExtHdrs(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ extHdr func(nextHdr uint8) ([]byte, uint8)
+ shouldAccept bool
+ }{
+ {
+ name: "None",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "destination with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing - atomic fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ nextHdr, 0, 0, 0, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Fragment extension header.
+ routingExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, fragmentExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hop by hop (with skippable unknown) - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "routing - hop by hop (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Hop By Hop extension header with skippable unknown option.
+ nextHdr, 0, 62, 4, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with discard action for unknown option.
+ routingExtHdrID, 0, 65, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with discard action for unknown
+ // option.
+ nextHdr, 0, 65, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8}
+ udpLength := header.UDPMinimumSize + len(udpPayload)
+ extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber))
+ extHdrLen := len(extHdrBytes)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength)
+
+ // Serialize UDP message.
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), udpPayload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(udpPayload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ // Copy extension header bytes between the UDP message and the IPv6
+ // fixed header.
+ copy(hdr.Prepend(extHdrLen), extHdrBytes)
+
+ // Serialize IPv6 fixed header.
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: ipv6NextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stats := s.Stats().UDP.PacketsReceived
+
+ if !test.shouldAccept {
+ if got := stats.Value(); got != 0 {
+ t.Errorf("got UDP Rx Packets = %d, want = 0", got)
+ }
+
+ return
+ }
+
+ // Expect a UDP packet.
+ if got := stats.Value(); got != 1 {
+ t.Errorf("got UDP Rx Packets = %d, want = 1", got)
+ }
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have any more UDP packets.
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
+
+// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
+type fragmentData struct {
+ nextHdr uint8
+ data buffer.VectorisedView
+}
+
+func TestReceiveIPv6Fragments(t *testing.T) {
+ const nicID = 1
+ const udpPayload1Length = 256
+ const udpPayload2Length = 128
+ const fragmentExtHdrLen = 8
+ // Note, not all routing extension headers will be 8 bytes but this test
+ // uses 8 byte routing extension headers for most sub tests.
+ const routingExtHdrLen = 8
+
+ udpGen := func(payload []byte, multiplier uint8) buffer.View {
+ payloadLen := len(payload)
+ for i := 0; i < payloadLen; i++ {
+ payload[i] = uint8(i) * multiplier
+ }
+
+ udpLength := header.UDPMinimumSize + payloadLen
+
+ hdr := buffer.NewPrependable(udpLength)
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), payload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(payload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ return hdr.View()
+ }
+
+ var udpPayload1Buf [udpPayload1Length]byte
+ udpPayload1 := udpPayload1Buf[:]
+ ipv6Payload1 := udpGen(udpPayload1, 1)
+
+ var udpPayload2Buf [udpPayload2Length]byte
+ udpPayload2 := udpPayload2Buf[:]
+ ipv6Payload2 := udpGen(udpPayload2, 2)
+
+ tests := []struct {
+ name string
+ expectedPayload []byte
+ fragments []fragmentData
+ expectedPayloads [][]byte
+ }{
+ {
+ name: "No fragmentation",
+ fragments: []fragmentData{
+ {
+ nextHdr: uint8(header.UDPProtocolNumber),
+ data: ipv6Payload1.ToVectorisedView(),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Atomic fragment",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with different IDs",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with per-fragment routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with per-fragment routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16 byte routing extension header is in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // the 16 byte routing extension header is in this fagment.
+ fragmentExtHdrLen+8+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16 byte routing extension header is in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // the 16 byte routing extension header is in this fagment.
+ fragmentExtHdrLen+8+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal"
+ // fragmented traffic.
+ {
+ name: "Two fragments with atomic",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ // This fragment has the same ID as the other fragments but is an atomic
+ // fragment. It should not interfere with the other fragments.
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
+
+ ipv6Payload2,
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload2, udpPayload1},
+ },
+ {
+ name: "Two interleaved fragmented packets",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
+
+ ipv6Payload2[:32],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
+
+ ipv6Payload2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: vv,
+ })
+ }
+
+ if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
+ t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
+ }
+
+ for i, p := range test.expectedPayloads {
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ }
+ if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" {
+ t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 630fdefc5..7c9fc48d1 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"log"
"math/rand"
"time"
@@ -428,7 +429,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
if ref.getKind() != permanentTentative {
// The endpoint should be marked as tentative since we are starting DAD.
- log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
}
// Should not attempt to perform DAD on an address that is currently in the
@@ -440,7 +441,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
// See NIC.addAddressLocked.
- log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
@@ -476,7 +477,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
if ref.getKind() != permanentTentative {
// The endpoint should still be marked as tentative since we are still
// performing DAD on it.
- log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()))
}
dadDone := remaining == 0
@@ -546,9 +547,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
- log.Fatalf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
} else if c != nil {
- log.Fatalf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
}
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
@@ -949,7 +950,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
- log.Fatalf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix)
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix))
}
ndp.deprecateSLAACAddress(prefixState.ref)
@@ -1029,7 +1030,7 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen
ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
if err != nil {
- log.Fatalf("ndp: error when adding address %+v: %s", generatedAddr, err)
+ panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err))
}
return ref
@@ -1043,7 +1044,7 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen
func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
- log.Fatalf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix)
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix))
}
defer func() { ndp.slaacPrefixes[prefix] = prefixState }()
@@ -1144,7 +1145,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool)
if removeAddr {
if err := ndp.nic.removePermanentAddressLocked(addr); err != nil {
- log.Fatalf("ndp: removePermanentAddressLocked(%s): %s", addr, err)
+ panic(fmt.Sprintf("ndp: removePermanentAddressLocked(%s): %s", addr, err))
}
}
@@ -1193,7 +1194,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
}
if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
- log.Fatalf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)
+ panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
}
for prefix := range ndp.onLinkPrefixes {
@@ -1201,7 +1202,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
}
if got := len(ndp.onLinkPrefixes); got != 0 {
- log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)
+ panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
}
for router := range ndp.defaultRouters {
@@ -1209,7 +1210,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
}
if got := len(ndp.defaultRouters); got != 0 {
- log.Fatalf("ndp: still have discovered default routers after cleaning up; found = %d", got)
+ panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
}
}
@@ -1251,9 +1252,9 @@ func (ndp *ndpState) startSolicitingRouters() {
// header.IPv6AllRoutersMulticastAddress is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
- log.Fatalf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
} else if c != nil {
- log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index b6fa647ea..4835251bc 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -16,7 +16,6 @@ package stack
import (
"fmt"
- "log"
"reflect"
"sort"
"strings"
@@ -480,7 +479,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
// Should never happen as we got r from the primary IPv6 endpoint list and
// ScopeForIPv6Address only returns an error if addr is not an IPv6
// address.
- log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err)
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
}
cs = append(cs, ipv6AddrCandidate{
@@ -492,7 +491,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
if err != nil {
// primaryIPv6Endpoint should never be called with an invalid IPv6 address.
- log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
}
// Sort the addresses as per RFC 6724 section 5 rules 1-3.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 9505a4e92..9367de180 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -13,7 +13,10 @@
package stack
-import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
// A PacketBuffer contains all the data of a network packet.
//
@@ -59,6 +62,10 @@ type PacketBuffer struct {
// Hash is the transport layer hash of this packet. A value of zero
// indicates no valid hash has been set.
Hash uint32
+
+ // Owner is implemented by task to get the uid and gid.
+ // Only set for locally generated packets.
+ Owner tcpip.PacketOwner
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 75c119c99..c65b0c632 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -31,12 +31,14 @@ import (
)
const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testPort = 4096
+ testSrcAddrV4 = "\x0a\x00\x00\x01"
+ testDstAddrV4 = "\x0a\x00\x00\x02"
+
+ testDstPort = 1234
+ testSrcPort = 4096
)
type testContext struct {
@@ -59,11 +61,11 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
t.Fatalf("AddAddress IPv4 failed: %s", err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
t.Fatalf("AddAddress IPv6 failed: %s", err)
}
}
@@ -91,6 +93,47 @@ func newPayload() []byte {
return b
}
+func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: 0x80,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: testSrcAddrV4,
+ DstAddr: testDstAddrV4,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
+}
+
func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
@@ -102,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -115,7 +158,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -123,7 +166,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Inject packet.
c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
})
}
@@ -227,9 +272,12 @@ func TestBindToDeviceDistribution(t *testing.T) {
},
},
} {
- t.Run(test.name, func(t *testing.T) {
+ for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
+ "IPv4": ipv4.ProtocolNumber,
+ "IPv6": ipv6.ProtocolNumber,
+ } {
for device, wantDistribution := range test.wantDistributions {
- t.Run(string(device), func(t *testing.T) {
+ t.Run(test.name+protoName+string(device), func(t *testing.T) {
var devices []tcpip.NICID
for d := range test.wantDistributions {
devices = append(devices, d)
@@ -248,7 +296,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
defer close(ch)
var err *tcpip.Error
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
}
@@ -269,7 +317,17 @@ func TestBindToDeviceDistribution(t *testing.T) {
if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
}
- if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
+
+ var dstAddr tcpip.Address
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ dstAddr = testDstAddrV4
+ case ipv6.ProtocolNumber:
+ dstAddr = testDstAddrV6
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
}
}
@@ -285,11 +343,18 @@ func TestBindToDeviceDistribution(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
- c.sendV6Packet(payload,
- &headers{
- srcPort: testPort + port,
- dstPort: stackPort},
- device)
+ hdrs := &headers{
+ srcPort: testSrcPort + port,
+ dstPort: testDstPort,
+ }
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ c.sendV4Packet(payload, hdrs, device)
+ case ipv6.ProtocolNumber:
+ c.sendV6Packet(payload, hdrs, device)
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ }
ep := <-pollChannel
if _, _, err := ep.Read(nil); err != nil {
@@ -320,6 +385,6 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
})
}
- })
+ }
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 8ca9ac3cf..3084e6593 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -56,6 +56,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
+func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+
func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 3dc5d87d6..2ef3271f1 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -336,6 +336,15 @@ type ControlMessages struct {
PacketInfo IPPacketInfo
}
+// PacketOwner is used to get UID and GID of the packet.
+type PacketOwner interface {
+ // UID returns UID of the packet.
+ UID() uint32
+
+ // GID returns GID of the packet.
+ GID() uint32
+}
+
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
@@ -470,6 +479,9 @@ type Endpoint interface {
// Stats returns a reference to the endpoint stats.
Stats() EndpointStats
+
+ // SetOwner sets the task owner to the endpoint owner.
+ SetOwner(owner PacketOwner)
}
// EndpointInfo is the interface implemented by each endpoint info struct.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 613b12ead..b007302fb 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -73,6 +73,9 @@ type endpoint struct {
route stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
@@ -133,6 +136,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
// IPTables implements tcpip.Endpoint.IPTables.
func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
@@ -321,7 +328,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
switch e.NetProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.ID.LocalPort, v, e.ttl)
+ err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
case header.IPv6ProtocolNumber:
err = send6(route, e.ID.LocalPort, v, e.ttl)
@@ -415,7 +422,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
+func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return tcpip.ErrInvalidEndpointState
}
@@ -444,6 +451,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
Header: hdr,
Data: data.ToVectorisedView(),
TransportHeader: buffer.View(icmpv4),
+ Owner: owner,
})
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index df49d0995..23158173d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -392,3 +392,5 @@ func (ep *endpoint) Info() tcpip.EndpointInfo {
func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+
+func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 536dafd1e..337bc1c71 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -80,6 +80,9 @@ type endpoint struct {
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -159,6 +162,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
// IPTables implements tcpip.Endpoint.IPTables.
func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
@@ -348,10 +355,12 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
break
}
+
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: buffer.View(payloadBytes).ToVectorisedView(),
+ Owner: e.owner,
}); err != nil {
return 0, nil, err
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 375ca21f6..7a9dea4ac 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -276,7 +276,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// and then performs the TCP 3-way handshake.
//
// The new endpoint is returned with e.mu held.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
@@ -284,6 +284,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if err != nil {
return nil, err
}
+ ep.owner = owner
// listenEP is nil when listenContext is used by tcp.Forwarder.
deferAccept := time.Duration(0)
@@ -414,7 +415,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}()
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 1d245c2c6..3239a5911 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -745,7 +745,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp
func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
tf.txHash = e.txHash
- if err := sendTCP(r, tf, data, gso); err != nil {
+ if err := sendTCP(r, tf, data, gso, e.owner); err != nil {
e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
return err
}
@@ -787,7 +787,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
}
}
-func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
optLen := len(tf.opts)
if tf.rcvWnd > 0xffff {
tf.rcvWnd = 0xffff
@@ -816,6 +816,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
pkts[i].DataSize = packetSize
pkts[i].Data = data
pkts[i].Hash = tf.txHash
+ pkts[i].Owner = owner
buildTCPHdr(r, tf, &pkts[i], gso)
off += packetSize
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
@@ -833,14 +834,14 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
optLen := len(tf.opts)
if tf.rcvWnd > 0xffff {
tf.rcvWnd = 0xffff
}
if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
- return sendTCPBatch(r, tf, data, gso)
+ return sendTCPBatch(r, tf, data, gso, owner)
}
pkt := stack.PacketBuffer{
@@ -849,6 +850,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
DataSize: data.Size(),
Data: data,
Hash: tf.txHash,
+ Owner: owner,
}
buildTCPHdr(r, tf, &pkt, gso)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 1ebee0cfe..9b123e968 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -603,6 +603,9 @@ type endpoint struct {
// txHash is the transport layer hash to be set on outbound packets
// emitted by this endpoint.
txHash uint32
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -1132,6 +1135,10 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
// IPTables implements tcpip.Endpoint.IPTables.
func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index a094471b8..808410c92 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -157,7 +157,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
TSVal: r.synOptions.TSVal,
TSEcr: r.synOptions.TSEcr,
SACKPermitted: r.synOptions.SACKPermitted,
- }, queue)
+ }, queue, nil)
if err != nil {
return nil, err
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 1377107ca..dce9a1652 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -199,7 +199,7 @@ func replyWithReset(s *segment) {
seq: seq,
ack: ack,
rcvWnd: 0,
- }, buffer.VectorisedView{}, nil /* gso */)
+ }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */)
}
// SetOption implements stack.TransportProtocol.SetOption.
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index a3372ac58..120d3baa3 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -143,6 +143,9 @@ type endpoint struct {
// TODO(b/142022063): Add ability to save and restore per endpoint stats.
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
// +stateify savable
@@ -484,7 +487,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -886,7 +889,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -916,6 +919,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
Header: hdr,
Data: data,
TransportHeader: buffer.View(udp),
+ Owner: owner,
}); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
@@ -1356,3 +1360,7 @@ func (*endpoint) Wait() {}
func isBroadcastOrMulticast(a tcpip.Address) bool {
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
}
+
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
diff --git a/runsc/boot/fds.go b/runsc/boot/fds.go
index 417d2d5fb..5314b0f2a 100644
--- a/runsc/boot/fds.go
+++ b/runsc/boot/fds.go
@@ -34,7 +34,6 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
k := kernel.KernelFromContext(ctx)
fdTable := k.NewFDTable()
defer fdTable.DecRef()
- mounter := fs.FileOwnerFromContext(ctx)
var ttyFile *fs.File
for appFD, hostFD := range stdioFDs {
@@ -44,7 +43,7 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
// Import the file as a host TTY file.
if ttyFile == nil {
var err error
- appFile, err = host.ImportFile(ctx, hostFD, mounter, true /* isTTY */)
+ appFile, err = host.ImportFile(ctx, hostFD, true /* isTTY */)
if err != nil {
return nil, err
}
@@ -63,7 +62,7 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
} else {
// Import the file as a regular host file.
var err error
- appFile, err = host.ImportFile(ctx, hostFD, mounter, false /* isTTY */)
+ appFile, err = host.ImportFile(ctx, hostFD, false /* isTTY */)
if err != nil {
return nil, err
}
diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh
index b4a5211a5..0f46909ac 100755
--- a/scripts/iptables_tests.sh
+++ b/scripts/iptables_tests.sh
@@ -16,7 +16,7 @@
source $(dirname $0)/common.sh
-install_runsc_for_test iptables
+install_runsc_for_test iptables --net-raw
# Build the docker image for the test.
run //test/iptables/runner:runner-image --norun
@@ -26,5 +26,5 @@ test //test/iptables:iptables_test \
"--test_arg=--image=bazel/test/iptables/runner:runner-image"
test //test/iptables:iptables_test \
- "--test_arg=--runtime=runsc" \
+ "--test_arg=--runtime=${RUNTIME}" \
"--test_arg=--image=bazel/test/iptables/runner:runner-image"
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index 4ccd4cce7..41e0cfa8d 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -194,14 +194,11 @@ func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error {
- // After the container sets its DROP rule, we shouldn't be able to connect.
- // However, we may succeed in connecting if this runs before the container
- // sets the rule. To avoid this race, we retry connecting until
- // sendloopDuration has elapsed, ignoring whether the connect succeeds. The
- // test works becuase the container will error if a connection is
- // established after the rule is set.
+ // Ensure we cannot connect to the container.
for start := time.Now(); time.Since(start) < sendloopDuration; {
- connectTCP(ip, dropPort, sendloopDuration-time.Since(start))
+ if err := connectTCP(ip, dropPort, sendloopDuration-time.Since(start)); err == nil {
+ return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort)
+ }
}
return nil
@@ -232,14 +229,11 @@ func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error {
- // After the container sets its DROP rule, we shouldn't be able to connect.
- // However, we may succeed in connecting if this runs before the container
- // sets the rule. To avoid this race, we retry connecting until
- // sendloopDuration has elapsed, ignoring whether the connect succeeds. The
- // test works becuase the container will error if a connection is
- // established after the rule is set.
+ // Ensure we cannot connect to the container.
for start := time.Now(); time.Since(start) < sendloopDuration; {
- connectTCP(ip, acceptPort, sendloopDuration-time.Since(start))
+ if err := connectTCP(ip, acceptPort, sendloopDuration-time.Since(start)); err == nil {
+ return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort)
+ }
}
return nil
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index 4582d514c..f6d974b85 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -24,6 +24,11 @@ func init() {
RegisterTestCase(FilterOutputDropTCPSrcPort{})
RegisterTestCase(FilterOutputDestination{})
RegisterTestCase(FilterOutputInvertDestination{})
+ RegisterTestCase(FilterOutputAcceptTCPOwner{})
+ RegisterTestCase(FilterOutputDropTCPOwner{})
+ RegisterTestCase(FilterOutputAcceptUDPOwner{})
+ RegisterTestCase(FilterOutputDropUDPOwner{})
+ RegisterTestCase(FilterOutputOwnerFail{})
}
// FilterOutputDropTCPDestPort tests that connections are not accepted on
@@ -90,6 +95,144 @@ func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error {
return nil
}
+// FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted.
+type FilterOutputAcceptTCPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptTCPOwner) Name() string {
+ return "FilterOutputAcceptTCPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("connection on port %d should be accepted, but got dropped", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("connection destined to port %d should be accepted, but got dropped", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped.
+type FilterOutputDropTCPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPOwner) Name() string {
+ return "FilterOutputDropTCPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted.
+type FilterOutputAcceptUDPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptUDPOwner) Name() string {
+ return "FilterOutputAcceptUDPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Send UDP packets on acceptPort.
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error {
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped.
+type FilterOutputDropUDPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropUDPOwner) Name() string {
+ return "FilterOutputDropUDPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Send UDP packets on dropPort.
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error {
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets should not be received")
+ }
+
+ return nil
+}
+
+// FilterOutputOwnerFail tests that without uid/gid option, owner rule
+// will fail.
+type FilterOutputOwnerFail struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputOwnerFail) Name() string {
+ return "FilterOutputOwnerFail"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil {
+ return fmt.Errorf("Invalid argument")
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputOwnerFail) LocalAction(ip net.IP) error {
+ // no-op.
+ return nil
+}
+
// FilterOutputDestination tests that we can selectively allow packets to
// certain destinations.
type FilterOutputDestination struct{}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 7f1f70606..493d69052 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -274,6 +274,36 @@ func TestFilterOutputDropTCPSrcPort(t *testing.T) {
}
}
+func TestFilterOutputAcceptTCPOwner(t *testing.T) {
+ if err := singleTest(FilterOutputAcceptTCPOwner{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputDropTCPOwner(t *testing.T) {
+ if err := singleTest(FilterOutputDropTCPOwner{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputAcceptUDPOwner(t *testing.T) {
+ if err := singleTest(FilterOutputAcceptUDPOwner{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputDropUDPOwner(t *testing.T) {
+ if err := singleTest(FilterOutputDropUDPOwner{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputOwnerFail(t *testing.T) {
+ if err := singleTest(FilterOutputOwnerFail{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestJumpSerialize(t *testing.T) {
if err := singleTest(FilterInputSerializeJump{}); err != nil {
t.Fatal(err)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 636e5db12..d0c431234 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -3336,10 +3336,7 @@ cc_binary(
cc_binary(
name = "sysret_test",
testonly = 1,
- srcs = select_arch(
- amd64 = ["sysret.cc"],
- arm64 = [],
- ),
+ srcs = ["sysret.cc"],
linkstatic = 1,
deps = [
gtest,
diff --git a/test/syscalls/linux/sysret.cc b/test/syscalls/linux/sysret.cc
index 819fa655a..19ffbd85b 100644
--- a/test/syscalls/linux/sysret.cc
+++ b/test/syscalls/linux/sysret.cc
@@ -14,6 +14,8 @@
// Tests to verify that the behavior of linux and gvisor matches when
// 'sysret' returns to bad (aka non-canonical) %rip or %rsp.
+
+#include <linux/elf.h>
#include <sys/ptrace.h>
#include <sys/user.h>
@@ -32,6 +34,7 @@ constexpr uint64_t kNonCanonicalRsp = 0xFFFF000000000000;
class SysretTest : public ::testing::Test {
protected:
struct user_regs_struct regs_;
+ struct iovec iov;
pid_t child_;
void SetUp() override {
@@ -48,10 +51,15 @@ class SysretTest : public ::testing::Test {
// Parent.
int status;
+ memset(&iov, 0, sizeof(iov));
ASSERT_THAT(pid, SyscallSucceeds()); // Might still be < 0.
ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
- ASSERT_THAT(ptrace(PTRACE_GETREGS, pid, 0, &regs_), SyscallSucceeds());
+
+ iov.iov_base = &regs_;
+ iov.iov_len = sizeof(regs_);
+ ASSERT_THAT(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
child_ = pid;
}
@@ -61,13 +69,27 @@ class SysretTest : public ::testing::Test {
}
void SetRip(uint64_t newrip) {
+#if defined(__x86_64__)
regs_.rip = newrip;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, &regs_), SyscallSucceeds());
+#elif defined(__aarch64__)
+ regs_.pc = newrip;
+#else
+#error "Unknown architecture"
+#endif
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
}
void SetRsp(uint64_t newrsp) {
+#if defined(__x86_64__)
regs_.rsp = newrsp;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, &regs_), SyscallSucceeds());
+#elif defined(__aarch64__)
+ regs_.sp = newrsp;
+#else
+#error "Unknown architecture"
+#endif
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
}
// Wait waits for the child pid and returns the exit status.
@@ -104,8 +126,15 @@ TEST_F(SysretTest, BadRsp) {
SetRsp(kNonCanonicalRsp);
Detach();
int status = Wait();
+#if defined(__x86_64__)
EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGBUS)
<< "status = " << status;
+#elif defined(__aarch64__)
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV)
+ << "status = " << status;
+#else
+#error "Unknown architecture"
+#endif
}
} // namespace