summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/fs/ext/BUILD22
-rw-r--r--pkg/sentry/fs/ext/dentry.go33
-rw-r--r--pkg/sentry/fs/ext/disklayout/dirent_old.go3
-rw-r--r--pkg/sentry/fs/ext/disklayout/inode.go6
-rw-r--r--pkg/sentry/fs/ext/disklayout/superblock_32.go3
-rw-r--r--pkg/sentry/fs/ext/disklayout/superblock_64.go3
-rw-r--r--pkg/sentry/fs/ext/disklayout/superblock_old.go2
-rw-r--r--pkg/sentry/fs/ext/ext.go102
-rw-r--r--pkg/sentry/fs/ext/ext_test.go407
-rw-r--r--pkg/sentry/fs/ext/filesystem.go137
-rw-r--r--pkg/sentry/fs/inode_overlay.go6
-rw-r--r--pkg/sentry/kernel/kernel.go5
-rw-r--r--pkg/sentry/kernel/threads.go12
-rw-r--r--pkg/sentry/socket/epsocket/stack.go4
-rw-r--r--pkg/tcpip/network/arp/arp.go11
-rw-r--r--pkg/tcpip/network/ip_test.go36
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go16
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go16
-rw-r--r--pkg/tcpip/stack/nic.go66
-rw-r--r--pkg/tcpip/stack/registration.go8
-rw-r--r--pkg/tcpip/stack/stack.go23
-rw-r--r--pkg/tcpip/stack/stack_test.go376
-rw-r--r--pkg/tcpip/tcpip.go13
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go4
25 files changed, 1064 insertions, 256 deletions
diff --git a/pkg/sentry/fs/ext/BUILD b/pkg/sentry/fs/ext/BUILD
index 3ba278e08..2c15875f5 100644
--- a/pkg/sentry/fs/ext/BUILD
+++ b/pkg/sentry/fs/ext/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"dentry.go",
"ext.go",
+ "filesystem.go",
"inode.go",
"utils.go",
],
@@ -15,7 +16,10 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/sentry/context",
"//pkg/sentry/fs/ext/disklayout",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
"//pkg/syserror",
],
)
@@ -23,11 +27,27 @@ go_library(
go_test(
name = "ext_test",
size = "small",
- srcs = ["extent_test.go"],
+ srcs = [
+ "ext_test.go",
+ "extent_test.go",
+ ],
+ data = [
+ "//pkg/sentry/fs/ext:assets/bigfile.txt",
+ "//pkg/sentry/fs/ext:assets/file.txt",
+ "//pkg/sentry/fs/ext:assets/tiny.ext2",
+ "//pkg/sentry/fs/ext:assets/tiny.ext3",
+ "//pkg/sentry/fs/ext:assets/tiny.ext4",
+ ],
embed = [":ext"],
deps = [
+ "//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/sentry/context",
+ "//pkg/sentry/context/contexttest",
"//pkg/sentry/fs/ext/disklayout",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//runsc/test/testutil",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/sentry/fs/ext/dentry.go b/pkg/sentry/fs/ext/dentry.go
index 71cd217df..054fb42b6 100644
--- a/pkg/sentry/fs/ext/dentry.go
+++ b/pkg/sentry/fs/ext/dentry.go
@@ -14,10 +14,43 @@
package ext
+import (
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
// dentry implements vfs.DentryImpl.
type dentry struct {
+ vfsd vfs.Dentry
+
// inode is the inode represented by this dentry. Multiple Dentries may
// share a single non-directory Inode (with hard links). inode is
// immutable.
inode *inode
}
+
+// Compiles only if dentry implements vfs.DentryImpl.
+var _ vfs.DentryImpl = (*dentry)(nil)
+
+// newDentry is the dentry constructor.
+func newDentry(in *inode) *dentry {
+ d := &dentry{
+ inode: in,
+ }
+ d.vfsd.Init(d)
+ return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef(vfsfs *vfs.Filesystem) {
+ d.inode.incRef()
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool {
+ return d.inode.tryIncRef()
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef(vfsfs *vfs.Filesystem) {
+ d.inode.decRef(vfsfs.Impl().(*filesystem))
+}
diff --git a/pkg/sentry/fs/ext/disklayout/dirent_old.go b/pkg/sentry/fs/ext/disklayout/dirent_old.go
index 2e0f9c812..6fff12a6e 100644
--- a/pkg/sentry/fs/ext/disklayout/dirent_old.go
+++ b/pkg/sentry/fs/ext/disklayout/dirent_old.go
@@ -17,8 +17,7 @@ package disklayout
import "gvisor.dev/gvisor/pkg/sentry/fs"
// DirentOld represents the old directory entry struct which does not contain
-// the file type. This emulates Linux's ext4_dir_entry struct. This is used in
-// ext2, ext3 and sometimes in ext4.
+// the file type. This emulates Linux's ext4_dir_entry struct.
//
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
diff --git a/pkg/sentry/fs/ext/disklayout/inode.go b/pkg/sentry/fs/ext/disklayout/inode.go
index 9ab9a4988..88ae913f5 100644
--- a/pkg/sentry/fs/ext/disklayout/inode.go
+++ b/pkg/sentry/fs/ext/disklayout/inode.go
@@ -20,6 +20,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
+// Special inodes. See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#special-inodes.
+const (
+ // RootDirInode is the inode number of the root directory inode.
+ RootDirInode = 2
+)
+
// The Inode interface must be implemented by structs representing ext inodes.
// The inode stores all the metadata pertaining to the file (except for the
// file name which is held by the directory entry). It does NOT expose all
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_32.go b/pkg/sentry/fs/ext/disklayout/superblock_32.go
index 587e4afaa..53e515fd3 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_32.go
+++ b/pkg/sentry/fs/ext/disklayout/superblock_32.go
@@ -15,7 +15,8 @@
package disklayout
// SuperBlock32Bit implements SuperBlock and represents the 32-bit version of
-// the ext4_super_block struct in fs/ext4/ext4.h.
+// the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if
+// RevLevel = DynamicRev and 64-bit feature is disabled.
type SuperBlock32Bit struct {
// We embed the old superblock struct here because the 32-bit version is just
// an extension of the old version.
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_64.go b/pkg/sentry/fs/ext/disklayout/superblock_64.go
index a2c2278fb..7c1053fb4 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_64.go
+++ b/pkg/sentry/fs/ext/disklayout/superblock_64.go
@@ -17,7 +17,8 @@ package disklayout
// SuperBlock64Bit implements SuperBlock and represents the 64-bit version of
// the ext4_super_block struct in fs/ext4/ext4.h. This sums up to be exactly
// 1024 bytes (smallest possible block size) and hence the superblock always
-// fits in no more than one data block.
+// fits in no more than one data block. Should only be used when the 64-bit
+// feature is set.
type SuperBlock64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_old.go b/pkg/sentry/fs/ext/disklayout/superblock_old.go
index 5a64aaaa1..9221e0251 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_old.go
+++ b/pkg/sentry/fs/ext/disklayout/superblock_old.go
@@ -15,7 +15,7 @@
package disklayout
// SuperBlockOld implements SuperBlock and represents the old version of the
-// superblock struct in ext2 and ext3 systems.
+// superblock struct. Should be used only if RevLevel = OldRev.
type SuperBlockOld struct {
InodesCountRaw uint32
BlocksCountLo uint32
diff --git a/pkg/sentry/fs/ext/ext.go b/pkg/sentry/fs/ext/ext.go
index 7f4287b01..10e235fb1 100644
--- a/pkg/sentry/fs/ext/ext.go
+++ b/pkg/sentry/fs/ext/ext.go
@@ -16,86 +16,82 @@
package ext
import (
+ "errors"
+ "fmt"
"io"
- "sync"
+ "os"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
-// filesystem implements vfs.FilesystemImpl.
-type filesystem struct {
- // mu serializes changes to the Dentry tree and the usage of the read seeker.
- mu sync.Mutex
-
- // dev is the ReadSeeker for the underlying fs device. It is protected by mu.
- //
- // The ext filesystems aim to maximize locality, i.e. place all the data
- // blocks of a file close together. On a spinning disk, locality reduces the
- // amount of movement of the head hence speeding up IO operations. On an SSD
- // there are no moving parts but locality increases the size of each transer
- // request. Hence, having mutual exclusion on the read seeker while reading a
- // file *should* help in achieving the intended performance gains.
- //
- // Note: This synchronization was not coupled with the ReadSeeker itself
- // because we want to synchronize across read/seek operations for the
- // performance gains mentioned above. Helps enforcing one-file-at-a-time IO.
- dev io.ReadSeeker
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// Compiles only if filesystemType implements vfs.FilesystemType.
+var _ vfs.FilesystemType = (*filesystemType)(nil)
+
+// getDeviceFd returns the read seeker to the underlying device.
+// Currently there are two ways of mounting an ext(2/3/4) fs:
+// 1. Specify a mount with our internal special MountType in the OCI spec.
+// 2. Expose the device to the container and mount it from application layer.
+func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReadSeeker, error) {
+ if opts.InternalData == nil {
+ // User mount call.
+ // TODO(b/134676337): Open the device specified by `source` and return that.
+ panic("unimplemented")
+ }
- // inodeCache maps absolute inode numbers to the corresponding Inode struct.
- // Inodes should be removed from this once their reference count hits 0.
- //
- // Protected by mu because every addition and removal from this corresponds to
- // a change in the dentry tree.
- inodeCache map[uint32]*inode
+ // NewFilesystem call originated from within the sentry.
+ fd, ok := opts.InternalData.(uintptr)
+ if !ok {
+ return nil, errors.New("internal data for ext fs must be a uintptr containing the file descriptor to device")
+ }
- // sb represents the filesystem superblock. Immutable after initialization.
- sb disklayout.SuperBlock
+ // We do not close this file because that would close the underlying device
+ // file descriptor (which is required for reading the fs from disk).
+ // TODO(b/134676337): Use pkg/fd instead.
+ deviceFile := os.NewFile(fd, source)
+ if deviceFile == nil {
+ return nil, fmt.Errorf("ext4 device file descriptor is not valid: %d", fd)
+ }
- // bgs represents all the block group descriptors for the filesystem.
- // Immutable after initialization.
- bgs []disklayout.BlockGroup
+ return deviceFile, nil
}
-// newFilesystem is the filesystem constructor.
-func newFilesystem(dev io.ReadSeeker) (*filesystem, error) {
- fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
- var err error
+// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
+func (fstype filesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ dev, err := getDeviceFd(source, opts)
+ if err != nil {
+ return nil, nil, err
+ }
+ fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
+ fs.vfsfs.Init(&fs)
fs.sb, err = readSuperBlock(dev)
if err != nil {
- return nil, err
+ return nil, nil, err
}
if fs.sb.Magic() != linux.EXT_SUPER_MAGIC {
// mount(2) specifies that EINVAL should be returned if the superblock is
// invalid.
- return nil, syserror.EINVAL
+ return nil, nil, syserror.EINVAL
}
fs.bgs, err = readBlockGroups(dev, fs.sb)
if err != nil {
- return nil, err
- }
-
- return &fs, nil
-}
-
-// getOrCreateInode gets the inode corresponding to the inode number passed in.
-// It creates a new one with the given inode number if one does not exist.
-//
-// Preconditions: must be holding fs.mu.
-func (fs *filesystem) getOrCreateInode(inodeNum uint32) (*inode, error) {
- if in, ok := fs.inodeCache[inodeNum]; ok {
- return in, nil
+ return nil, nil, err
}
- in, err := newInode(fs.dev, fs.sb, fs.bgs, inodeNum)
+ rootInode, err := fs.getOrCreateInode(disklayout.RootDirInode)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- fs.inodeCache[inodeNum] = in
- return in, nil
+ return &fs.vfsfs, &newDentry(rootInode).vfsd, nil
}
diff --git a/pkg/sentry/fs/ext/ext_test.go b/pkg/sentry/fs/ext/ext_test.go
new file mode 100644
index 000000000..ee7f7907c
--- /dev/null
+++ b/pkg/sentry/fs/ext/ext_test.go
@@ -0,0 +1,407 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "fmt"
+ "os"
+ "path"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+
+ "gvisor.dev/gvisor/runsc/test/testutil"
+)
+
+const (
+ assetsDir = "pkg/sentry/fs/ext/assets"
+)
+
+var (
+ ext2ImagePath = path.Join(assetsDir, "tiny.ext2")
+ ext3ImagePath = path.Join(assetsDir, "tiny.ext3")
+ ext4ImagePath = path.Join(assetsDir, "tiny.ext4")
+)
+
+func beginning(_ uint64) uint64 {
+ return 0
+}
+
+func middle(i uint64) uint64 {
+ return i / 2
+}
+
+func end(i uint64) uint64 {
+ return i
+}
+
+// setUp opens imagePath as an ext Filesystem and returns all necessary
+// elements required to run tests. If error is non-nil, it also returns a tear
+// down function which must be called after the test is run for clean up.
+func setUp(t *testing.T, imagePath string) (context.Context, *vfs.Filesystem, *vfs.Dentry, func(), error) {
+ localImagePath, err := testutil.FindFile(imagePath)
+ if err != nil {
+ return nil, nil, nil, nil, fmt.Errorf("failed to open local image at path %s: %v", imagePath, err)
+ }
+
+ f, err := os.Open(localImagePath)
+ if err != nil {
+ return nil, nil, nil, nil, err
+ }
+
+ // Mount the ext4 fs and retrieve the inode structure for the file.
+ mockCtx := contexttest.Context(t)
+ fs, d, err := filesystemType{}.NewFilesystem(mockCtx, nil, localImagePath, vfs.NewFilesystemOptions{InternalData: f.Fd()})
+ if err != nil {
+ f.Close()
+ return nil, nil, nil, nil, err
+ }
+
+ tearDown := func() {
+ if err := f.Close(); err != nil {
+ t.Fatalf("tearDown failed: %v", err)
+ }
+ }
+ return mockCtx, fs, d, tearDown, nil
+}
+
+// TestRootDir tests that the root directory inode is correctly initialized and
+// returned from setUp.
+func TestRootDir(t *testing.T) {
+ type inodeProps struct {
+ Mode linux.FileMode
+ UID auth.KUID
+ GID auth.KGID
+ Size uint64
+ InodeSize uint16
+ Links uint16
+ Flags disklayout.InodeFlags
+ }
+
+ type rootDirTest struct {
+ name string
+ image string
+ wantInode inodeProps
+ }
+
+ tests := []rootDirTest{
+ {
+ name: "ext4 root dir",
+ image: ext4ImagePath,
+ wantInode: inodeProps{
+ Mode: linux.ModeDirectory | 0755,
+ Size: 0x400,
+ InodeSize: 0x80,
+ Links: 3,
+ Flags: disklayout.InodeFlags{Extents: true},
+ },
+ },
+ {
+ name: "ext3 root dir",
+ image: ext3ImagePath,
+ wantInode: inodeProps{
+ Mode: linux.ModeDirectory | 0755,
+ Size: 0x400,
+ InodeSize: 0x80,
+ Links: 3,
+ },
+ },
+ {
+ name: "ext2 root dir",
+ image: ext2ImagePath,
+ wantInode: inodeProps{
+ Mode: linux.ModeDirectory | 0755,
+ Size: 0x400,
+ InodeSize: 0x80,
+ Links: 3,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ _, _, vfsd, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ d, ok := vfsd.Impl().(*dentry)
+ if !ok {
+ t.Fatalf("ext dentry of incorrect type: %T", vfsd.Impl())
+ }
+
+ // Offload inode contents into local structs for comparison.
+ gotInode := inodeProps{
+ Mode: d.inode.diskInode.Mode(),
+ UID: d.inode.diskInode.UID(),
+ GID: d.inode.diskInode.GID(),
+ Size: d.inode.diskInode.Size(),
+ InodeSize: d.inode.diskInode.InodeSize(),
+ Links: d.inode.diskInode.LinksCount(),
+ Flags: d.inode.diskInode.Flags(),
+ }
+
+ if diff := cmp.Diff(gotInode, test.wantInode); diff != "" {
+ t.Errorf("inode mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+// TestFilesystemInit tests that the filesystem superblock and block group
+// descriptors are correctly read in and initialized.
+func TestFilesystemInit(t *testing.T) {
+ // sb only contains the immutable properties of the superblock.
+ type sb struct {
+ InodesCount uint32
+ BlocksCount uint64
+ MaxMountCount uint16
+ FirstDataBlock uint32
+ BlockSize uint64
+ BlocksPerGroup uint32
+ ClusterSize uint64
+ ClustersPerGroup uint32
+ InodeSize uint16
+ InodesPerGroup uint32
+ BgDescSize uint16
+ Magic uint16
+ Revision disklayout.SbRevision
+ CompatFeatures disklayout.CompatFeatures
+ IncompatFeatures disklayout.IncompatFeatures
+ RoCompatFeatures disklayout.RoCompatFeatures
+ }
+
+ // bg only contains the immutable properties of the block group descriptor.
+ type bg struct {
+ InodeTable uint64
+ BlockBitmap uint64
+ InodeBitmap uint64
+ ExclusionBitmap uint64
+ Flags disklayout.BGFlags
+ }
+
+ type fsInitTest struct {
+ name string
+ image string
+ wantSb sb
+ wantBgs []bg
+ }
+
+ tests := []fsInitTest{
+ {
+ name: "ext4 filesystem init",
+ image: ext4ImagePath,
+ wantSb: sb{
+ InodesCount: 0x10,
+ BlocksCount: 0x40,
+ MaxMountCount: 0xffff,
+ FirstDataBlock: 0x1,
+ BlockSize: 0x400,
+ BlocksPerGroup: 0x2000,
+ ClusterSize: 0x400,
+ ClustersPerGroup: 0x2000,
+ InodeSize: 0x80,
+ InodesPerGroup: 0x10,
+ BgDescSize: 0x40,
+ Magic: linux.EXT_SUPER_MAGIC,
+ Revision: disklayout.DynamicRev,
+ CompatFeatures: disklayout.CompatFeatures{
+ ExtAttr: true,
+ ResizeInode: true,
+ DirIndex: true,
+ },
+ IncompatFeatures: disklayout.IncompatFeatures{
+ DirentFileType: true,
+ Extents: true,
+ Is64Bit: true,
+ FlexBg: true,
+ },
+ RoCompatFeatures: disklayout.RoCompatFeatures{
+ Sparse: true,
+ LargeFile: true,
+ HugeFile: true,
+ DirNlink: true,
+ ExtraIsize: true,
+ MetadataCsum: true,
+ },
+ },
+ wantBgs: []bg{
+ {
+ InodeTable: 0x23,
+ BlockBitmap: 0x3,
+ InodeBitmap: 0x13,
+ Flags: disklayout.BGFlags{
+ InodeZeroed: true,
+ },
+ },
+ },
+ },
+ {
+ name: "ext3 filesystem init",
+ image: ext3ImagePath,
+ wantSb: sb{
+ InodesCount: 0x10,
+ BlocksCount: 0x40,
+ MaxMountCount: 0xffff,
+ FirstDataBlock: 0x1,
+ BlockSize: 0x400,
+ BlocksPerGroup: 0x2000,
+ ClusterSize: 0x400,
+ ClustersPerGroup: 0x2000,
+ InodeSize: 0x80,
+ InodesPerGroup: 0x10,
+ BgDescSize: 0x20,
+ Magic: linux.EXT_SUPER_MAGIC,
+ Revision: disklayout.DynamicRev,
+ CompatFeatures: disklayout.CompatFeatures{
+ ExtAttr: true,
+ ResizeInode: true,
+ DirIndex: true,
+ },
+ IncompatFeatures: disklayout.IncompatFeatures{
+ DirentFileType: true,
+ },
+ RoCompatFeatures: disklayout.RoCompatFeatures{
+ Sparse: true,
+ LargeFile: true,
+ },
+ },
+ wantBgs: []bg{
+ {
+ InodeTable: 0x5,
+ BlockBitmap: 0x3,
+ InodeBitmap: 0x4,
+ Flags: disklayout.BGFlags{
+ InodeZeroed: true,
+ },
+ },
+ },
+ },
+ {
+ name: "ext2 filesystem init",
+ image: ext2ImagePath,
+ wantSb: sb{
+ InodesCount: 0x10,
+ BlocksCount: 0x40,
+ MaxMountCount: 0xffff,
+ FirstDataBlock: 0x1,
+ BlockSize: 0x400,
+ BlocksPerGroup: 0x2000,
+ ClusterSize: 0x400,
+ ClustersPerGroup: 0x2000,
+ InodeSize: 0x80,
+ InodesPerGroup: 0x10,
+ BgDescSize: 0x20,
+ Magic: linux.EXT_SUPER_MAGIC,
+ Revision: disklayout.DynamicRev,
+ CompatFeatures: disklayout.CompatFeatures{
+ ExtAttr: true,
+ ResizeInode: true,
+ DirIndex: true,
+ },
+ IncompatFeatures: disklayout.IncompatFeatures{
+ DirentFileType: true,
+ },
+ RoCompatFeatures: disklayout.RoCompatFeatures{
+ Sparse: true,
+ LargeFile: true,
+ },
+ },
+ wantBgs: []bg{
+ {
+ InodeTable: 0x5,
+ BlockBitmap: 0x3,
+ InodeBitmap: 0x4,
+ Flags: disklayout.BGFlags{
+ InodeZeroed: true,
+ },
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ _, vfsfs, _, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ fs, ok := vfsfs.Impl().(*filesystem)
+ if !ok {
+ t.Fatalf("ext filesystem of incorrect type: %T", vfsfs.Impl())
+ }
+
+ // Offload superblock and block group descriptors contents into
+ // local structs for comparison.
+ totalFreeInodes := uint32(0)
+ totalFreeBlocks := uint64(0)
+ gotSb := sb{
+ InodesCount: fs.sb.InodesCount(),
+ BlocksCount: fs.sb.BlocksCount(),
+ MaxMountCount: fs.sb.MaxMountCount(),
+ FirstDataBlock: fs.sb.FirstDataBlock(),
+ BlockSize: fs.sb.BlockSize(),
+ BlocksPerGroup: fs.sb.BlocksPerGroup(),
+ ClusterSize: fs.sb.ClusterSize(),
+ ClustersPerGroup: fs.sb.ClustersPerGroup(),
+ InodeSize: fs.sb.InodeSize(),
+ InodesPerGroup: fs.sb.InodesPerGroup(),
+ BgDescSize: fs.sb.BgDescSize(),
+ Magic: fs.sb.Magic(),
+ Revision: fs.sb.Revision(),
+ CompatFeatures: fs.sb.CompatibleFeatures(),
+ IncompatFeatures: fs.sb.IncompatibleFeatures(),
+ RoCompatFeatures: fs.sb.ReadOnlyCompatibleFeatures(),
+ }
+ gotNumBgs := len(fs.bgs)
+ gotBgs := make([]bg, gotNumBgs)
+ for i := 0; i < gotNumBgs; i++ {
+ gotBgs[i].InodeTable = fs.bgs[i].InodeTable()
+ gotBgs[i].BlockBitmap = fs.bgs[i].BlockBitmap()
+ gotBgs[i].InodeBitmap = fs.bgs[i].InodeBitmap()
+ gotBgs[i].ExclusionBitmap = fs.bgs[i].ExclusionBitmap()
+ gotBgs[i].Flags = fs.bgs[i].Flags()
+
+ totalFreeInodes += fs.bgs[i].FreeInodesCount()
+ totalFreeBlocks += uint64(fs.bgs[i].FreeBlocksCount())
+ }
+
+ if diff := cmp.Diff(gotSb, test.wantSb); diff != "" {
+ t.Errorf("superblock mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(gotBgs, test.wantBgs); diff != "" {
+ t.Errorf("block group descriptors mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(totalFreeInodes, fs.sb.FreeInodesCount()); diff != "" {
+ t.Errorf("total free inodes mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(totalFreeBlocks, fs.sb.FreeBlocksCount()); diff != "" {
+ t.Errorf("total free blocks mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fs/ext/filesystem.go b/pkg/sentry/fs/ext/filesystem.go
new file mode 100644
index 000000000..7150e75a5
--- /dev/null
+++ b/pkg/sentry/fs/ext/filesystem.go
@@ -0,0 +1,137 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "io"
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ // TODO(b/134676337): Remove when all methods have been implemented.
+ vfs.FilesystemImpl
+
+ vfsfs vfs.Filesystem
+
+ // mu serializes changes to the Dentry tree and the usage of the read seeker.
+ mu sync.Mutex
+
+ // dev is the ReadSeeker for the underlying fs device. It is protected by mu.
+ //
+ // The ext filesystems aim to maximize locality, i.e. place all the data
+ // blocks of a file close together. On a spinning disk, locality reduces the
+ // amount of movement of the head hence speeding up IO operations. On an SSD
+ // there are no moving parts but locality increases the size of each transer
+ // request. Hence, having mutual exclusion on the read seeker while reading a
+ // file *should* help in achieving the intended performance gains.
+ //
+ // Note: This synchronization was not coupled with the ReadSeeker itself
+ // because we want to synchronize across read/seek operations for the
+ // performance gains mentioned above. Helps enforcing one-file-at-a-time IO.
+ dev io.ReadSeeker
+
+ // inodeCache maps absolute inode numbers to the corresponding Inode struct.
+ // Inodes should be removed from this once their reference count hits 0.
+ //
+ // Protected by mu because every addition and removal from this corresponds to
+ // a change in the dentry tree.
+ inodeCache map[uint32]*inode
+
+ // sb represents the filesystem superblock. Immutable after initialization.
+ sb disklayout.SuperBlock
+
+ // bgs represents all the block group descriptors for the filesystem.
+ // Immutable after initialization.
+ bgs []disklayout.BlockGroup
+}
+
+// Compiles only if filesystem implements vfs.FilesystemImpl.
+var _ vfs.FilesystemImpl = (*filesystem)(nil)
+
+// getOrCreateInode gets the inode corresponding to the inode number passed in.
+// It creates a new one with the given inode number if one does not exist.
+//
+// Preconditions: must be holding fs.mu.
+func (fs *filesystem) getOrCreateInode(inodeNum uint32) (*inode, error) {
+ if in, ok := fs.inodeCache[inodeNum]; ok {
+ return in, nil
+ }
+
+ in, err := newInode(fs.dev, fs.sb, fs.bgs, inodeNum)
+ if err != nil {
+ return nil, err
+ }
+
+ fs.inodeCache[inodeNum] = in
+ return in, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+}
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // This is a readonly filesystem for now.
+ return nil
+}
+
+// The vfs.FilesystemImpl functions below return EROFS because their respective
+// man pages say that EROFS must be returned if the path resolves to a file on
+// a read-only filesystem.
+
+// TODO(b/134676337): Implement path traversal and return EROFS only if the
+// path resolves to a Dentry within ext fs.
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ return syserror.EROFS
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return syserror.EROFS
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
+ return syserror.EROFS
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ return syserror.EROFS
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ return syserror.EROFS
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return syserror.EROFS
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ return syserror.EROFS
+}
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index e0602da17..246b97161 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -571,12 +571,6 @@ func overlayCheck(ctx context.Context, o *overlayEntry, p PermMask) error {
if o.upper != nil {
err = o.upper.check(ctx, p)
} else {
- if p.Write {
- // Since writes will be redirected to the upper filesystem, the lower
- // filesystem need not be writable, but must be readable for copy-up.
- p.Write = false
- p.Read = true
- }
err = o.lower.check(ctx, p)
}
o.copyMu.RUnlock()
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 70f5a3f0b..4c2d48e65 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -240,6 +240,9 @@ type InitKernelArgs struct {
// RootAbstractSocketNamespace is the root Abstract Socket namespace.
RootAbstractSocketNamespace *AbstractSocketNamespace
+
+ // PIDNamespace is the root PID namespace.
+ PIDNamespace *PIDNamespace
}
// Init initialize the Kernel with no tasks.
@@ -262,7 +265,7 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.featureSet = args.FeatureSet
k.timekeeper = args.Timekeeper
- k.tasks = newTaskSet()
+ k.tasks = newTaskSet(args.PIDNamespace)
k.rootUserNamespace = args.RootUserNamespace
k.rootUTSNamespace = args.RootUTSNamespace
k.rootIPCNamespace = args.RootIPCNamespace
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index b21b182fc..8267929a6 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -90,9 +90,9 @@ type TaskSet struct {
}
// newTaskSet returns a new, empty TaskSet.
-func newTaskSet() *TaskSet {
- ts := &TaskSet{}
- ts.Root = newPIDNamespace(ts, nil /* parent */, auth.NewRootUserNamespace())
+func newTaskSet(pidns *PIDNamespace) *TaskSet {
+ ts := &TaskSet{Root: pidns}
+ pidns.owner = ts
return ts
}
@@ -186,6 +186,12 @@ func newPIDNamespace(ts *TaskSet, parent *PIDNamespace, userns *auth.UserNamespa
}
}
+// NewRootPIDNamespace creates the root PID namespace. 'owner' is not available
+// yet when root namespace is created and must be set by caller.
+func NewRootPIDNamespace(userns *auth.UserNamespace) *PIDNamespace {
+ return newPIDNamespace(nil, nil, userns)
+}
+
// NewChild returns a new, empty PID namespace that is a child of ns. Authority
// over the new PID namespace is controlled by userns.
func (ns *PIDNamespace) NewChild(userns *auth.UserNamespace) *PIDNamespace {
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go
index 7eef19f74..8fe489c0e 100644
--- a/pkg/sentry/socket/epsocket/stack.go
+++ b/pkg/sentry/socket/epsocket/stack.go
@@ -75,8 +75,8 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
addrs = append(addrs, inet.InterfaceAddr{
Family: family,
- PrefixLen: uint8(len(a.Address) * 8),
- Addr: []byte(a.Address),
+ PrefixLen: uint8(a.AddressWithPrefix.PrefixLen),
+ Addr: []byte(a.AddressWithPrefix.Address),
// TODO(b/68878065): Other fields.
})
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index cb35635fc..60070874d 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -46,7 +46,6 @@ const (
// endpoint implements stack.NetworkEndpoint.
type endpoint struct {
nicid tcpip.NICID
- addr tcpip.Address
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
}
@@ -73,6 +72,10 @@ func (e *endpoint) ID() *stack.NetworkEndpointID {
return &stack.NetworkEndpointID{ProtocolAddress}
}
+func (e *endpoint) PrefixLen() int {
+ return 0
+}
+
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
@@ -122,19 +125,19 @@ type protocol struct {
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
+func (p *protocol) DefaultPrefixLen() int { return 0 }
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.ARP(v)
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
- if addr != ProtocolAddress {
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ if addrWithPrefix.Address != ProtocolAddress {
return nil, tcpip.ErrBadLocalAddress
}
return &endpoint{
nicid: nicid,
- addr: addr,
linkEP: sender,
linkAddrCache: linkAddrCache,
}, nil
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 8ff428445..55e9eec99 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -29,16 +29,18 @@ import (
)
const (
- localIpv4Addr = "\x0a\x00\x00\x01"
- remoteIpv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ localIpv4Addr = "\x0a\x00\x00\x01"
+ localIpv4PrefixLen = 24
+ remoteIpv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ localIpv6PrefixLen = 120
+ remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
)
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
@@ -197,7 +199,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
func TestIPv4Send(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, nil, &o)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -229,7 +231,7 @@ func TestIPv4Send(t *testing.T) {
func TestIPv4Receive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -295,7 +297,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -359,7 +361,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
func TestIPv4FragmentationReceive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -426,7 +428,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
func TestIPv6Send(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, nil, &o)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -458,7 +460,7 @@ func TestIPv6Send(t *testing.T) {
func TestIPv6Receive(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -532,7 +534,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e44a73d96..b7a06f525 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -49,16 +49,18 @@ const (
type endpoint struct {
nicid tcpip.NICID
id stack.NetworkEndpointID
+ prefixLen int
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
e := &endpoint{
nicid: nicid,
- id: stack.NetworkEndpointID{LocalAddress: addr},
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
@@ -93,6 +95,11 @@ func (e *endpoint) ID() *stack.NetworkEndpointID {
return &e.id
}
+// PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
+func (e *endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -338,6 +345,11 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
+// DefaultPrefixLen returns the IPv4 default prefix length.
+func (p *protocol) DefaultPrefixLen() int {
+ return header.IPv4AddressSize * 8
+}
+
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d46d68e73..726362c87 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -99,7 +99,11 @@ func TestICMPCounts(t *testing.T) {
}},
)
- ep, err := s.NetworkProtocolInstance(ProtocolNumber).NewEndpoint(0, lladdr1, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
if err != nil {
t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index e3e8739fd..331a8bdaa 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -46,6 +46,7 @@ const (
type endpoint struct {
nicid tcpip.NICID
id stack.NetworkEndpointID
+ prefixLen int
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
@@ -72,6 +73,11 @@ func (e *endpoint) ID() *stack.NetworkEndpointID {
return &e.id
}
+// PrefixLen returns the ipv6 endpoint subnet prefix length in bits.
+func (e *endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
// Capabilities implements stack.NetworkEndpoint.Capabilities.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
@@ -172,6 +178,11 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
+// DefaultPrefixLen returns the IPv6 default prefix length.
+func (p *protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
@@ -179,10 +190,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &endpoint{
nicid: nicid,
- id: stack.NetworkEndpointID{LocalAddress: addr},
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 30c0dee42..3e6ff4afb 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -129,7 +129,7 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
n.mu.RLock()
defer n.mu.RUnlock()
@@ -148,21 +148,16 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add
}
if r == nil {
- return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress
+ return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress
}
- address := r.ep.ID().LocalAddress
+ addressWithPrefix := tcpip.AddressWithPrefix{
+ Address: r.ep.ID().LocalAddress,
+ PrefixLen: r.ep.PrefixLen(),
+ }
r.decRef()
- // Find the least-constrained matching subnet for the address, if one
- // exists, and return it.
- var subnet tcpip.Subnet
- for _, s := range n.subnets {
- if s.Contains(address) && !subnet.Contains(s.ID()) {
- subnet = s
- }
- }
- return address, subnet, nil
+ return addressWithPrefix, nil
}
// primaryEndpoint returns the primary endpoint of n for the given network
@@ -213,23 +208,26 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
n.mu.Lock()
ref = n.endpoints[id]
if ref == nil || !ref.tryIncRef() {
- ref, _ = n.addAddressLocked(protocol, address, peb, true)
- if ref != nil {
- ref.holdsInsertRef = false
+ if netProto, ok := n.stack.networkProtocols[protocol]; ok {
+ addrWithPrefix := tcpip.AddressWithPrefix{address, netProto.DefaultPrefixLen()}
+ ref, _ = n.addAddressLocked(protocol, addrWithPrefix, peb, true)
+ if ref != nil {
+ ref.holdsInsertRef = false
+ }
}
}
n.mu.Unlock()
return ref
}
-func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
}
// Create the new network endpoint.
- ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
+ ep, err := netProto.NewEndpoint(n.id, addrWithPrefix, n.stack, n, n.linkEP)
if err != nil {
return nil, err
}
@@ -278,16 +276,10 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.
// AddAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
- return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify
-// whether the new endpoint can be primary or not.
-func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addAddressLocked(protocol, addr, peb, false)
+ _, err := n.addAddressLocked(protocol, addrWithPrefix, peb, false)
n.mu.Unlock()
return err
@@ -298,10 +290,13 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
- for nid, ep := range n.endpoints {
+ for nid, ref := range n.endpoints {
addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: ep.protocol,
- Address: nid.LocalAddress,
+ Protocol: ref.protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: nid.LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ },
})
}
return addrs
@@ -415,7 +410,12 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
- if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil {
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()}
+ if _, err := n.addAddressLocked(protocol, addrWithPrefix, NeverPrimaryEndpoint, false); err != nil {
return err
}
}
@@ -572,7 +572,13 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r
n.mu.Unlock()
return ref
}
- ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true)
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ n.mu.Unlock()
+ return nil
+ }
+ addrWithPrefix := tcpip.AddressWithPrefix{dst, netProto.DefaultPrefixLen()}
+ ref, err := n.addAddressLocked(protocol, addrWithPrefix, CanBePrimaryEndpoint, true)
n.mu.Unlock()
if err == nil {
ref.holdsInsertRef = false
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 462265281..2037eef9f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -181,6 +181,9 @@ type NetworkEndpoint interface {
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
+ // PrefixLen returns the network endpoint's subnet prefix length in bits.
+ PrefixLen() int
+
// NICID returns the id of the NIC this endpoint belongs to.
NICID() tcpip.NICID
@@ -203,12 +206,15 @@ type NetworkProtocol interface {
// than this targeted at this protocol.
MinimumPacketSize() int
+ // DefaultPrefixLen returns the protocol's default prefix length.
+ DefaultPrefixLen() int
+
// ParsePorts returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+ NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3e8fb2a6c..57b8a9994 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -751,9 +751,26 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
+// AddAddressWithPrefix adds a new network-layer address/prefixLen to the
+// specified NIC.
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix) *tcpip.Error {
+ return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, CanBePrimaryEndpoint)
+}
+
// AddAddressWithOptions is the same as AddAddress, but allows you to specify
// whether the new endpoint can be primary or not.
func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+ netProto, ok := s.networkProtocols[protocol]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()}
+ return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, peb)
+}
+
+// AddAddressWithPrefixAndOptions is the same as AddAddressWithPrefixLen,
+// but allows you to specify whether the new endpoint can be primary or not.
+func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -762,7 +779,7 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt
return tcpip.ErrUnknownNICID
}
- return nic.AddAddressWithOptions(protocol, addr, peb)
+ return nic.AddAddress(protocol, addrWithPrefix, peb)
}
// AddSubnet adds a subnet range to the specified NIC.
@@ -821,7 +838,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
// address if no primary addresses exist. Returns an error if the NIC doesn't
// exist or has no endpoints.
-func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -829,7 +846,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return nic.getMainNICAddress(protocol)
}
- return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID
+ return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 959071dbe..9d082bba4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"math"
+ "sort"
"strings"
"testing"
@@ -32,8 +33,9 @@ import (
)
const (
- fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
- fakeNetHeaderLen = 12
+ fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+ fakeNetHeaderLen = 12
+ fakeDefaultPrefixLen = 8
// fakeControlProtocol is used for control packets that represent
// destination port unreachable.
@@ -55,6 +57,7 @@ const (
type fakeNetworkEndpoint struct {
nicid tcpip.NICID
id stack.NetworkEndpointID
+ prefixLen int
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
linkEP stack.LinkEndpoint
@@ -68,6 +71,10 @@ func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
return f.nicid
}
+func (f *fakeNetworkEndpoint) PrefixLen() int {
+ return f.prefixLen
+}
+
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
@@ -170,14 +177,19 @@ func (f *fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
+func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
+ return fakeDefaultPrefixLen
+}
+
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
- id: stack.NetworkEndpointID{addr},
+ id: stack.NetworkEndpointID{addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
linkEP: linkEP,
@@ -212,15 +224,15 @@ func TestNetworkReceive(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -280,13 +292,13 @@ func TestNetworkReceive(t *testing.T) {
func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
+ t.Fatal("FindRoute failed:", err)
}
defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil {
- t.Errorf("WritePacket failed: %v", err)
+ t.Error("WritePacket failed:", err)
}
}
@@ -297,13 +309,13 @@ func TestNetworkSend(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("NewNIC failed: %v", err)
+ t.Fatal("NewNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
// Make sure that the link-layer endpoint received the outbound packet.
@@ -321,28 +333,28 @@ func TestNetworkSendMultiRoute(t *testing.T) {
id1, linkEP1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
id2, linkEP2 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(2, id2); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
// Set a route table that sends all packets with odd destination
@@ -371,7 +383,7 @@ func TestNetworkSendMultiRoute(t *testing.T) {
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
+ t.Fatal("FindRoute failed:", err)
}
defer r.Release()
@@ -388,7 +400,7 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr,
func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
_, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err)
+ t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute)
}
}
@@ -400,28 +412,28 @@ func TestRoutes(t *testing.T) {
id1, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
id2, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(2, id2); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
// Set a route table that sends all packets with odd destination
@@ -464,11 +476,11 @@ func TestAddressRemoval(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -486,7 +498,7 @@ func TestAddressRemoval(t *testing.T) {
// Remove the address, then check that packet doesn't get delivered
// anymore.
if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
+ t.Fatal("RemoveAddress failed:", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
@@ -496,7 +508,7 @@ func TestAddressRemoval(t *testing.T) {
// Check that removing the same address fails.
if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress failed: %v", err)
+ t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
}
@@ -505,11 +517,11 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -531,7 +543,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
// Get a route, check that packet is still deliverable.
r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
+ t.Fatal("FindRoute failed:", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
@@ -542,7 +554,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
// Remove the address, then check that packet is still deliverable
// because the route is keeping the address alive.
if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
+ t.Fatal("RemoveAddress failed:", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
@@ -552,7 +564,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
// Check that removing the same address fails.
if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress failed: %v", err)
+ t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
// Release the route, then check that packet is not deliverable anymore.
@@ -568,7 +580,7 @@ func TestPromiscuousMode(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -590,7 +602,7 @@ func TestPromiscuousMode(t *testing.T) {
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
- t.Fatalf("SetPromiscuousMode failed: %v", err)
+ t.Fatal("SetPromiscuousMode failed:", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
@@ -601,13 +613,13 @@ func TestPromiscuousMode(t *testing.T) {
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err)
+ t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute)
}
// Set promiscuous mode to false, then check that packet can't be
// delivered anymore.
if err := s.SetPromiscuousMode(1, false); err != nil {
- t.Fatalf("SetPromiscuousMode failed: %v", err)
+ t.Fatal("SetPromiscuousMode failed:", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
@@ -624,11 +636,11 @@ func TestAddressSpoofing(t *testing.T) {
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -645,11 +657,11 @@ func TestAddressSpoofing(t *testing.T) {
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
if err := s.SetSpoofing(1, true); err != nil {
- t.Fatalf("SetSpoofing failed: %v", err)
+ t.Fatal("SetSpoofing failed:", err)
}
r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
+ t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != srcAddr {
t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
@@ -664,17 +676,17 @@ func TestBroadcastNeedsNoRoute(t *testing.T) {
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{})
// If there is no endpoint, it won't work.
if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, header.IPv4Any, err)
+ t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
@@ -735,7 +747,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{})
@@ -791,7 +803,7 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -806,10 +818,10 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
fakeNet.packetCount[1] = 0
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
+ t.Fatal("NewSubnet failed:", err)
}
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
+ t.Fatal("AddSubnet failed:", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
@@ -824,7 +836,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -839,10 +851,10 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
fakeNet.packetCount[1] = 0
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
+ t.Fatal("NewSubnet failed:", err)
}
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
+ t.Fatal("AddSubnet failed:", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
@@ -894,38 +906,38 @@ func TestSubnetAddRemove(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
addr := tcpip.Address("\x01\x01\x01\x01")
mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
subnet, err := tcpip.NewSubnet(addr, mask)
if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
+ t.Fatal("NewSubnet failed:", err)
}
if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatalf("ContainsSubnet failed: %v", err)
+ t.Fatal("ContainsSubnet failed:", err)
} else if contained {
t.Fatal("got s.ContainsSubnet(...) = true, want = false")
}
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
+ t.Fatal("AddSubnet failed:", err)
}
if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatalf("ContainsSubnet failed: %v", err)
+ t.Fatal("ContainsSubnet failed:", err)
} else if !contained {
t.Fatal("got s.ContainsSubnet(...) = false, want = true")
}
if err := s.RemoveSubnet(1, subnet); err != nil {
- t.Fatalf("RemoveSubnet failed: %v", err)
+ t.Fatal("RemoveSubnet failed:", err)
}
if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatalf("ContainsSubnet failed: %v", err)
+ t.Fatal("ContainsSubnet failed:", err)
} else if contained {
t.Fatal("got s.ContainsSubnet(...) = true, want = false")
}
@@ -941,11 +953,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
// Insert <canBe> primary and <never> never-primary addresses.
// Each one will add a network endpoint to the NIC.
- primaryAddrAdded := make(map[tcpip.Address]tcpip.Subnet)
+ primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{})
for i := 0; i < canBe+never; i++ {
var behavior stack.PrimaryEndpointBehavior
if i < canBe {
@@ -953,46 +965,39 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
} else {
behavior = stack.NeverPrimaryEndpoint
}
- // Add an address and in case of a primary one also add a
- // subnet.
+ // Add an address and in case of a primary one include a
+ // prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
- if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions failed: %v", err)
- }
if behavior == stack.CanBePrimaryEndpoint {
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(address)))
- subnet, err := tcpip.NewSubnet(address, mask)
- if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
+ addressWithPrefix := tcpip.AddressWithPrefix{address, addrLen * 8}
+ if err := s.AddAddressWithPrefixAndOptions(1, fakeNetNumber, addressWithPrefix, behavior); err != nil {
+ t.Fatal("AddAddressWithPrefixAndOptions failed:", err)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
+ // Remember the address/prefix.
+ primaryAddrAdded[addressWithPrefix] = struct{}{}
+ } else {
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
}
- // Remember the address/subnet.
- primaryAddrAdded[address] = subnet
}
}
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
- // address/subnet matches what we added.
+ // address/prefixLen matches what we added.
if len(primaryAddrAdded) == 0 {
// No primary addresses present, expect an error.
- if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %v", err, tcpip.ErrNoLinkAddress)
+ if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress)
}
} else {
// At least one primary address was added, expect a valid
- // address and subnet.
- gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber)
+ // address and prefixLen.
+ gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber)
if err != nil {
- t.Fatalf("GetMainNICAddress failed: %v", err)
- }
- expectedSubnet, ok := primaryAddrAdded[gotAddress]
- if !ok {
- t.Fatalf("GetMainNICAddress: got address = %v, wanted any in {%v}", gotAddress, primaryAddrAdded)
+ t.Fatal("GetMainNICAddress failed:", err)
}
- if gotSubnet != expectedSubnet {
- t.Fatalf("GetMainNICAddress: got subnet = %v, wanted %v", gotSubnet, expectedSubnet)
+ if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok {
+ t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded)
}
}
})
@@ -1007,65 +1012,194 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
for _, tc := range []struct {
- name string
- address tcpip.Address
+ name string
+ address tcpip.Address
+ prefixLen int
}{
- {"IPv4", "\x01\x01\x01\x01"},
- {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
+ {"IPv4", "\x01\x01\x01\x01", 24},
+ {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116},
} {
t.Run(tc.name, func(t *testing.T) {
- address := tc.address
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(address)))
- subnet, err := tcpip.NewSubnet(address, mask)
- if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
+ addressWithPrefix := tcpip.AddressWithPrefix{tc.address, tc.prefixLen}
+
+ if err := s.AddAddressWithPrefix(1, fakeNetNumber, addressWithPrefix); err != nil {
+ t.Fatal("AddAddressWithPrefix failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ // Check that we get the right initial address and prefix length.
+ if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ } else if gotAddressWithPrefix != addressWithPrefix {
+ t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, addressWithPrefix)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed: %v", err)
+ if err := s.RemoveAddress(1, addressWithPrefix.Address); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
}
- // Check that we get the right initial address and subnet.
- if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
- t.Fatalf("GetMainNICAddress failed: %v", err)
- } else if gotAddress != address {
- t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address)
- } else if gotSubnet != subnet {
- t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet)
+ // Check that we get an error after removal.
+ if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress)
}
+ })
+ }
+}
+
+// Simple network address generator. Good for 255 addresses.
+type addressGenerator struct{ cnt byte }
+
+func (g *addressGenerator) next(addrLen int) tcpip.Address {
+ g.cnt++
+ return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen))
+}
+
+func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
+ if len(gotAddresses) != len(expectedAddresses) {
+ t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses))
+ }
+
+ sort.Slice(gotAddresses, func(i, j int) bool {
+ return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address
+ })
+ sort.Slice(expectedAddresses, func(i, j int) bool {
+ return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address
+ })
+
+ for i, gotAddr := range gotAddresses {
+ expectedAddr := expectedAddresses[i]
+ if gotAddr != expectedAddr {
+ t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr)
+ }
+ }
+}
+
+func TestAddAddress(t *testing.T) {
+ const nicid = 1
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
- if err := s.RemoveSubnet(1, subnet); err != nil {
- t.Fatalf("RemoveSubnet failed: %v", err)
+ var addrGen addressGenerator
+ expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
+ for _, addrLen := range []int{4, 16} {
+ address := addrGen.next(addrLen)
+ if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil {
+ t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
+ }
+ expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
+ })
+ }
+
+ gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ verifyAddresses(t, expectedAddresses, gotAddresses)
+}
+
+func TestAddAddressWithPrefix(t *testing.T) {
+ const nicid = 1
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ var addrGen addressGenerator
+ addrLenRange := []int{4, 16}
+ prefixLenRange := []int{8, 13, 20, 32}
+ expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
+ for _, addrLen := range addrLenRange {
+ for _, prefixLen := range prefixLenRange {
+ address := addrGen.next(addrLen)
+ if err := s.AddAddressWithPrefix(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}); err != nil {
+ t.Errorf("AddAddressWithPrefix(address=%s, prefixLen=%d) failed: %s", address, prefixLen, err)
}
+ expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen},
+ })
+ }
+ }
- if err := s.RemoveAddress(1, address); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
+ gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ verifyAddresses(t, expectedAddresses, gotAddresses)
+}
+
+func TestAddAddressWithOptions(t *testing.T) {
+ const nicid = 1
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ addrLenRange := []int{4, 16}
+ behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
+ expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
+ var addrGen addressGenerator
+ for _, addrLen := range addrLenRange {
+ for _, behavior := range behaviorRange {
+ address := addrGen.next(addrLen)
+ if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil {
+ t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
}
+ expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
+ })
+ }
+ }
- // Check that we get an error after removal.
- if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress)
+ gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ verifyAddresses(t, expectedAddresses, gotAddresses)
+}
+
+func TestAddAddressWithPrefixAndOptions(t *testing.T) {
+ const nicid = 1
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ addrLenRange := []int{4, 16}
+ prefixLenRange := []int{8, 13, 20, 32}
+ behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
+ expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
+ var addrGen addressGenerator
+ for _, addrLen := range addrLenRange {
+ for _, prefixLen := range prefixLenRange {
+ for _, behavior := range behaviorRange {
+ address := addrGen.next(addrLen)
+ if err := s.AddAddressWithPrefixAndOptions(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}, behavior); err != nil {
+ t.Fatalf("AddAddressWithPrefixAndOptions(address=%s, prefixLen=%d, behavior=%d) failed: %s", address, prefixLen, behavior, err)
+ }
+ expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen},
+ })
}
- })
+ }
}
+
+ gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestNICStats(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, linkEP1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatal("CreateNIC failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatal("AddAddress failed:", err)
}
// Route all packets for address \x01 to NIC 1.
s.SetRouteTable([]tcpip.Route{
@@ -1104,18 +1238,18 @@ func TestNICForwarding(t *testing.T) {
id1, linkEP1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
- t.Fatalf("CreateNIC #1 failed: %v", err)
+ t.Fatal("CreateNIC #1 failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress #1 failed: %v", err)
+ t.Fatal("AddAddress #1 failed:", err)
}
id2, linkEP2 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(2, id2); err != nil {
- t.Fatalf("CreateNIC #2 failed: %v", err)
+ t.Fatal("CreateNIC #2 failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress #2 failed: %v", err)
+ t.Fatal("AddAddress #2 failed:", err)
}
// Route all packets to address 3 to NIC 2.
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index c5d79da5e..4208c0303 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1059,14 +1059,23 @@ func ParseMACAddress(s string) (LinkAddress, error) {
return LinkAddress(addr), nil
}
+// AddressWithPrefix is an address with its subnet prefix length.
+type AddressWithPrefix struct {
+ // Address is a network address.
+ Address Address
+
+ // PrefixLen is the subnet prefix length.
+ PrefixLen int
+}
+
// ProtocolAddress is an address and the network protocol it is associated
// with.
type ProtocolAddress struct {
// Protocol is the protocol of the address.
Protocol NetworkProtocolNumber
- // Address is a network address.
- Address Address
+ // AddressWithPrefix is a network address with its subnet prefix length.
+ AddressWithPrefix AddressWithPrefix
}
// danglingEndpointsMu protects access to danglingEndpoints.
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 958d5712e..56c285f88 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -967,14 +967,14 @@ func TestTTL(t *testing.T) {
multicast = false
switch variant {
case "v4", "mapped":
- ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
if err != nil {
t.Fatal(err)
}
wantTTL = ep.DefaultTTL()
ep.Close()
case "v6":
- ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
if err != nil {
t.Fatal(err)
}