diff options
Diffstat (limited to 'pkg')
25 files changed, 1064 insertions, 256 deletions
diff --git a/pkg/sentry/fs/ext/BUILD b/pkg/sentry/fs/ext/BUILD index 3ba278e08..2c15875f5 100644 --- a/pkg/sentry/fs/ext/BUILD +++ b/pkg/sentry/fs/ext/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "dentry.go", "ext.go", + "filesystem.go", "inode.go", "utils.go", ], @@ -15,7 +16,10 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/sentry/context", "//pkg/sentry/fs/ext/disklayout", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", "//pkg/syserror", ], ) @@ -23,11 +27,27 @@ go_library( go_test( name = "ext_test", size = "small", - srcs = ["extent_test.go"], + srcs = [ + "ext_test.go", + "extent_test.go", + ], + data = [ + "//pkg/sentry/fs/ext:assets/bigfile.txt", + "//pkg/sentry/fs/ext:assets/file.txt", + "//pkg/sentry/fs/ext:assets/tiny.ext2", + "//pkg/sentry/fs/ext:assets/tiny.ext3", + "//pkg/sentry/fs/ext:assets/tiny.ext4", + ], embed = [":ext"], deps = [ + "//pkg/abi/linux", "//pkg/binary", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", "//pkg/sentry/fs/ext/disklayout", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", + "//runsc/test/testutil", "@com_github_google_go-cmp//cmp:go_default_library", "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/sentry/fs/ext/dentry.go b/pkg/sentry/fs/ext/dentry.go index 71cd217df..054fb42b6 100644 --- a/pkg/sentry/fs/ext/dentry.go +++ b/pkg/sentry/fs/ext/dentry.go @@ -14,10 +14,43 @@ package ext +import ( + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + // dentry implements vfs.DentryImpl. type dentry struct { + vfsd vfs.Dentry + // inode is the inode represented by this dentry. Multiple Dentries may // share a single non-directory Inode (with hard links). inode is // immutable. inode *inode } + +// Compiles only if dentry implements vfs.DentryImpl. +var _ vfs.DentryImpl = (*dentry)(nil) + +// newDentry is the dentry constructor. +func newDentry(in *inode) *dentry { + d := &dentry{ + inode: in, + } + d.vfsd.Init(d) + return d +} + +// IncRef implements vfs.DentryImpl.IncRef. +func (d *dentry) IncRef(vfsfs *vfs.Filesystem) { + d.inode.incRef() +} + +// TryIncRef implements vfs.DentryImpl.TryIncRef. +func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { + return d.inode.tryIncRef() +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *dentry) DecRef(vfsfs *vfs.Filesystem) { + d.inode.decRef(vfsfs.Impl().(*filesystem)) +} diff --git a/pkg/sentry/fs/ext/disklayout/dirent_old.go b/pkg/sentry/fs/ext/disklayout/dirent_old.go index 2e0f9c812..6fff12a6e 100644 --- a/pkg/sentry/fs/ext/disklayout/dirent_old.go +++ b/pkg/sentry/fs/ext/disklayout/dirent_old.go @@ -17,8 +17,7 @@ package disklayout import "gvisor.dev/gvisor/pkg/sentry/fs" // DirentOld represents the old directory entry struct which does not contain -// the file type. This emulates Linux's ext4_dir_entry struct. This is used in -// ext2, ext3 and sometimes in ext4. +// the file type. This emulates Linux's ext4_dir_entry struct. // // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain diff --git a/pkg/sentry/fs/ext/disklayout/inode.go b/pkg/sentry/fs/ext/disklayout/inode.go index 9ab9a4988..88ae913f5 100644 --- a/pkg/sentry/fs/ext/disklayout/inode.go +++ b/pkg/sentry/fs/ext/disklayout/inode.go @@ -20,6 +20,12 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) +// Special inodes. See https://www.kernel.org/doc/html/latest/filesystems/ext4/overview.html#special-inodes. +const ( + // RootDirInode is the inode number of the root directory inode. + RootDirInode = 2 +) + // The Inode interface must be implemented by structs representing ext inodes. // The inode stores all the metadata pertaining to the file (except for the // file name which is held by the directory entry). It does NOT expose all diff --git a/pkg/sentry/fs/ext/disklayout/superblock_32.go b/pkg/sentry/fs/ext/disklayout/superblock_32.go index 587e4afaa..53e515fd3 100644 --- a/pkg/sentry/fs/ext/disklayout/superblock_32.go +++ b/pkg/sentry/fs/ext/disklayout/superblock_32.go @@ -15,7 +15,8 @@ package disklayout // SuperBlock32Bit implements SuperBlock and represents the 32-bit version of -// the ext4_super_block struct in fs/ext4/ext4.h. +// the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if +// RevLevel = DynamicRev and 64-bit feature is disabled. type SuperBlock32Bit struct { // We embed the old superblock struct here because the 32-bit version is just // an extension of the old version. diff --git a/pkg/sentry/fs/ext/disklayout/superblock_64.go b/pkg/sentry/fs/ext/disklayout/superblock_64.go index a2c2278fb..7c1053fb4 100644 --- a/pkg/sentry/fs/ext/disklayout/superblock_64.go +++ b/pkg/sentry/fs/ext/disklayout/superblock_64.go @@ -17,7 +17,8 @@ package disklayout // SuperBlock64Bit implements SuperBlock and represents the 64-bit version of // the ext4_super_block struct in fs/ext4/ext4.h. This sums up to be exactly // 1024 bytes (smallest possible block size) and hence the superblock always -// fits in no more than one data block. +// fits in no more than one data block. Should only be used when the 64-bit +// feature is set. type SuperBlock64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fs/ext/disklayout/superblock_old.go b/pkg/sentry/fs/ext/disklayout/superblock_old.go index 5a64aaaa1..9221e0251 100644 --- a/pkg/sentry/fs/ext/disklayout/superblock_old.go +++ b/pkg/sentry/fs/ext/disklayout/superblock_old.go @@ -15,7 +15,7 @@ package disklayout // SuperBlockOld implements SuperBlock and represents the old version of the -// superblock struct in ext2 and ext3 systems. +// superblock struct. Should be used only if RevLevel = OldRev. type SuperBlockOld struct { InodesCountRaw uint32 BlocksCountLo uint32 diff --git a/pkg/sentry/fs/ext/ext.go b/pkg/sentry/fs/ext/ext.go index 7f4287b01..10e235fb1 100644 --- a/pkg/sentry/fs/ext/ext.go +++ b/pkg/sentry/fs/ext/ext.go @@ -16,86 +16,82 @@ package ext import ( + "errors" + "fmt" "io" - "sync" + "os" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) -// filesystem implements vfs.FilesystemImpl. -type filesystem struct { - // mu serializes changes to the Dentry tree and the usage of the read seeker. - mu sync.Mutex - - // dev is the ReadSeeker for the underlying fs device. It is protected by mu. - // - // The ext filesystems aim to maximize locality, i.e. place all the data - // blocks of a file close together. On a spinning disk, locality reduces the - // amount of movement of the head hence speeding up IO operations. On an SSD - // there are no moving parts but locality increases the size of each transer - // request. Hence, having mutual exclusion on the read seeker while reading a - // file *should* help in achieving the intended performance gains. - // - // Note: This synchronization was not coupled with the ReadSeeker itself - // because we want to synchronize across read/seek operations for the - // performance gains mentioned above. Helps enforcing one-file-at-a-time IO. - dev io.ReadSeeker +// filesystemType implements vfs.FilesystemType. +type filesystemType struct{} + +// Compiles only if filesystemType implements vfs.FilesystemType. +var _ vfs.FilesystemType = (*filesystemType)(nil) + +// getDeviceFd returns the read seeker to the underlying device. +// Currently there are two ways of mounting an ext(2/3/4) fs: +// 1. Specify a mount with our internal special MountType in the OCI spec. +// 2. Expose the device to the container and mount it from application layer. +func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReadSeeker, error) { + if opts.InternalData == nil { + // User mount call. + // TODO(b/134676337): Open the device specified by `source` and return that. + panic("unimplemented") + } - // inodeCache maps absolute inode numbers to the corresponding Inode struct. - // Inodes should be removed from this once their reference count hits 0. - // - // Protected by mu because every addition and removal from this corresponds to - // a change in the dentry tree. - inodeCache map[uint32]*inode + // NewFilesystem call originated from within the sentry. + fd, ok := opts.InternalData.(uintptr) + if !ok { + return nil, errors.New("internal data for ext fs must be a uintptr containing the file descriptor to device") + } - // sb represents the filesystem superblock. Immutable after initialization. - sb disklayout.SuperBlock + // We do not close this file because that would close the underlying device + // file descriptor (which is required for reading the fs from disk). + // TODO(b/134676337): Use pkg/fd instead. + deviceFile := os.NewFile(fd, source) + if deviceFile == nil { + return nil, fmt.Errorf("ext4 device file descriptor is not valid: %d", fd) + } - // bgs represents all the block group descriptors for the filesystem. - // Immutable after initialization. - bgs []disklayout.BlockGroup + return deviceFile, nil } -// newFilesystem is the filesystem constructor. -func newFilesystem(dev io.ReadSeeker) (*filesystem, error) { - fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)} - var err error +// NewFilesystem implements vfs.FilesystemType.NewFilesystem. +func (fstype filesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + dev, err := getDeviceFd(source, opts) + if err != nil { + return nil, nil, err + } + fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)} + fs.vfsfs.Init(&fs) fs.sb, err = readSuperBlock(dev) if err != nil { - return nil, err + return nil, nil, err } if fs.sb.Magic() != linux.EXT_SUPER_MAGIC { // mount(2) specifies that EINVAL should be returned if the superblock is // invalid. - return nil, syserror.EINVAL + return nil, nil, syserror.EINVAL } fs.bgs, err = readBlockGroups(dev, fs.sb) if err != nil { - return nil, err - } - - return &fs, nil -} - -// getOrCreateInode gets the inode corresponding to the inode number passed in. -// It creates a new one with the given inode number if one does not exist. -// -// Preconditions: must be holding fs.mu. -func (fs *filesystem) getOrCreateInode(inodeNum uint32) (*inode, error) { - if in, ok := fs.inodeCache[inodeNum]; ok { - return in, nil + return nil, nil, err } - in, err := newInode(fs.dev, fs.sb, fs.bgs, inodeNum) + rootInode, err := fs.getOrCreateInode(disklayout.RootDirInode) if err != nil { - return nil, err + return nil, nil, err } - fs.inodeCache[inodeNum] = in - return in, nil + return &fs.vfsfs, &newDentry(rootInode).vfsd, nil } diff --git a/pkg/sentry/fs/ext/ext_test.go b/pkg/sentry/fs/ext/ext_test.go new file mode 100644 index 000000000..ee7f7907c --- /dev/null +++ b/pkg/sentry/fs/ext/ext_test.go @@ -0,0 +1,407 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "os" + "path" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + + "gvisor.dev/gvisor/runsc/test/testutil" +) + +const ( + assetsDir = "pkg/sentry/fs/ext/assets" +) + +var ( + ext2ImagePath = path.Join(assetsDir, "tiny.ext2") + ext3ImagePath = path.Join(assetsDir, "tiny.ext3") + ext4ImagePath = path.Join(assetsDir, "tiny.ext4") +) + +func beginning(_ uint64) uint64 { + return 0 +} + +func middle(i uint64) uint64 { + return i / 2 +} + +func end(i uint64) uint64 { + return i +} + +// setUp opens imagePath as an ext Filesystem and returns all necessary +// elements required to run tests. If error is non-nil, it also returns a tear +// down function which must be called after the test is run for clean up. +func setUp(t *testing.T, imagePath string) (context.Context, *vfs.Filesystem, *vfs.Dentry, func(), error) { + localImagePath, err := testutil.FindFile(imagePath) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to open local image at path %s: %v", imagePath, err) + } + + f, err := os.Open(localImagePath) + if err != nil { + return nil, nil, nil, nil, err + } + + // Mount the ext4 fs and retrieve the inode structure for the file. + mockCtx := contexttest.Context(t) + fs, d, err := filesystemType{}.NewFilesystem(mockCtx, nil, localImagePath, vfs.NewFilesystemOptions{InternalData: f.Fd()}) + if err != nil { + f.Close() + return nil, nil, nil, nil, err + } + + tearDown := func() { + if err := f.Close(); err != nil { + t.Fatalf("tearDown failed: %v", err) + } + } + return mockCtx, fs, d, tearDown, nil +} + +// TestRootDir tests that the root directory inode is correctly initialized and +// returned from setUp. +func TestRootDir(t *testing.T) { + type inodeProps struct { + Mode linux.FileMode + UID auth.KUID + GID auth.KGID + Size uint64 + InodeSize uint16 + Links uint16 + Flags disklayout.InodeFlags + } + + type rootDirTest struct { + name string + image string + wantInode inodeProps + } + + tests := []rootDirTest{ + { + name: "ext4 root dir", + image: ext4ImagePath, + wantInode: inodeProps{ + Mode: linux.ModeDirectory | 0755, + Size: 0x400, + InodeSize: 0x80, + Links: 3, + Flags: disklayout.InodeFlags{Extents: true}, + }, + }, + { + name: "ext3 root dir", + image: ext3ImagePath, + wantInode: inodeProps{ + Mode: linux.ModeDirectory | 0755, + Size: 0x400, + InodeSize: 0x80, + Links: 3, + }, + }, + { + name: "ext2 root dir", + image: ext2ImagePath, + wantInode: inodeProps{ + Mode: linux.ModeDirectory | 0755, + Size: 0x400, + InodeSize: 0x80, + Links: 3, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, _, vfsd, tearDown, err := setUp(t, test.image) + if err != nil { + t.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + d, ok := vfsd.Impl().(*dentry) + if !ok { + t.Fatalf("ext dentry of incorrect type: %T", vfsd.Impl()) + } + + // Offload inode contents into local structs for comparison. + gotInode := inodeProps{ + Mode: d.inode.diskInode.Mode(), + UID: d.inode.diskInode.UID(), + GID: d.inode.diskInode.GID(), + Size: d.inode.diskInode.Size(), + InodeSize: d.inode.diskInode.InodeSize(), + Links: d.inode.diskInode.LinksCount(), + Flags: d.inode.diskInode.Flags(), + } + + if diff := cmp.Diff(gotInode, test.wantInode); diff != "" { + t.Errorf("inode mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// TestFilesystemInit tests that the filesystem superblock and block group +// descriptors are correctly read in and initialized. +func TestFilesystemInit(t *testing.T) { + // sb only contains the immutable properties of the superblock. + type sb struct { + InodesCount uint32 + BlocksCount uint64 + MaxMountCount uint16 + FirstDataBlock uint32 + BlockSize uint64 + BlocksPerGroup uint32 + ClusterSize uint64 + ClustersPerGroup uint32 + InodeSize uint16 + InodesPerGroup uint32 + BgDescSize uint16 + Magic uint16 + Revision disklayout.SbRevision + CompatFeatures disklayout.CompatFeatures + IncompatFeatures disklayout.IncompatFeatures + RoCompatFeatures disklayout.RoCompatFeatures + } + + // bg only contains the immutable properties of the block group descriptor. + type bg struct { + InodeTable uint64 + BlockBitmap uint64 + InodeBitmap uint64 + ExclusionBitmap uint64 + Flags disklayout.BGFlags + } + + type fsInitTest struct { + name string + image string + wantSb sb + wantBgs []bg + } + + tests := []fsInitTest{ + { + name: "ext4 filesystem init", + image: ext4ImagePath, + wantSb: sb{ + InodesCount: 0x10, + BlocksCount: 0x40, + MaxMountCount: 0xffff, + FirstDataBlock: 0x1, + BlockSize: 0x400, + BlocksPerGroup: 0x2000, + ClusterSize: 0x400, + ClustersPerGroup: 0x2000, + InodeSize: 0x80, + InodesPerGroup: 0x10, + BgDescSize: 0x40, + Magic: linux.EXT_SUPER_MAGIC, + Revision: disklayout.DynamicRev, + CompatFeatures: disklayout.CompatFeatures{ + ExtAttr: true, + ResizeInode: true, + DirIndex: true, + }, + IncompatFeatures: disklayout.IncompatFeatures{ + DirentFileType: true, + Extents: true, + Is64Bit: true, + FlexBg: true, + }, + RoCompatFeatures: disklayout.RoCompatFeatures{ + Sparse: true, + LargeFile: true, + HugeFile: true, + DirNlink: true, + ExtraIsize: true, + MetadataCsum: true, + }, + }, + wantBgs: []bg{ + { + InodeTable: 0x23, + BlockBitmap: 0x3, + InodeBitmap: 0x13, + Flags: disklayout.BGFlags{ + InodeZeroed: true, + }, + }, + }, + }, + { + name: "ext3 filesystem init", + image: ext3ImagePath, + wantSb: sb{ + InodesCount: 0x10, + BlocksCount: 0x40, + MaxMountCount: 0xffff, + FirstDataBlock: 0x1, + BlockSize: 0x400, + BlocksPerGroup: 0x2000, + ClusterSize: 0x400, + ClustersPerGroup: 0x2000, + InodeSize: 0x80, + InodesPerGroup: 0x10, + BgDescSize: 0x20, + Magic: linux.EXT_SUPER_MAGIC, + Revision: disklayout.DynamicRev, + CompatFeatures: disklayout.CompatFeatures{ + ExtAttr: true, + ResizeInode: true, + DirIndex: true, + }, + IncompatFeatures: disklayout.IncompatFeatures{ + DirentFileType: true, + }, + RoCompatFeatures: disklayout.RoCompatFeatures{ + Sparse: true, + LargeFile: true, + }, + }, + wantBgs: []bg{ + { + InodeTable: 0x5, + BlockBitmap: 0x3, + InodeBitmap: 0x4, + Flags: disklayout.BGFlags{ + InodeZeroed: true, + }, + }, + }, + }, + { + name: "ext2 filesystem init", + image: ext2ImagePath, + wantSb: sb{ + InodesCount: 0x10, + BlocksCount: 0x40, + MaxMountCount: 0xffff, + FirstDataBlock: 0x1, + BlockSize: 0x400, + BlocksPerGroup: 0x2000, + ClusterSize: 0x400, + ClustersPerGroup: 0x2000, + InodeSize: 0x80, + InodesPerGroup: 0x10, + BgDescSize: 0x20, + Magic: linux.EXT_SUPER_MAGIC, + Revision: disklayout.DynamicRev, + CompatFeatures: disklayout.CompatFeatures{ + ExtAttr: true, + ResizeInode: true, + DirIndex: true, + }, + IncompatFeatures: disklayout.IncompatFeatures{ + DirentFileType: true, + }, + RoCompatFeatures: disklayout.RoCompatFeatures{ + Sparse: true, + LargeFile: true, + }, + }, + wantBgs: []bg{ + { + InodeTable: 0x5, + BlockBitmap: 0x3, + InodeBitmap: 0x4, + Flags: disklayout.BGFlags{ + InodeZeroed: true, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, vfsfs, _, tearDown, err := setUp(t, test.image) + if err != nil { + t.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + fs, ok := vfsfs.Impl().(*filesystem) + if !ok { + t.Fatalf("ext filesystem of incorrect type: %T", vfsfs.Impl()) + } + + // Offload superblock and block group descriptors contents into + // local structs for comparison. + totalFreeInodes := uint32(0) + totalFreeBlocks := uint64(0) + gotSb := sb{ + InodesCount: fs.sb.InodesCount(), + BlocksCount: fs.sb.BlocksCount(), + MaxMountCount: fs.sb.MaxMountCount(), + FirstDataBlock: fs.sb.FirstDataBlock(), + BlockSize: fs.sb.BlockSize(), + BlocksPerGroup: fs.sb.BlocksPerGroup(), + ClusterSize: fs.sb.ClusterSize(), + ClustersPerGroup: fs.sb.ClustersPerGroup(), + InodeSize: fs.sb.InodeSize(), + InodesPerGroup: fs.sb.InodesPerGroup(), + BgDescSize: fs.sb.BgDescSize(), + Magic: fs.sb.Magic(), + Revision: fs.sb.Revision(), + CompatFeatures: fs.sb.CompatibleFeatures(), + IncompatFeatures: fs.sb.IncompatibleFeatures(), + RoCompatFeatures: fs.sb.ReadOnlyCompatibleFeatures(), + } + gotNumBgs := len(fs.bgs) + gotBgs := make([]bg, gotNumBgs) + for i := 0; i < gotNumBgs; i++ { + gotBgs[i].InodeTable = fs.bgs[i].InodeTable() + gotBgs[i].BlockBitmap = fs.bgs[i].BlockBitmap() + gotBgs[i].InodeBitmap = fs.bgs[i].InodeBitmap() + gotBgs[i].ExclusionBitmap = fs.bgs[i].ExclusionBitmap() + gotBgs[i].Flags = fs.bgs[i].Flags() + + totalFreeInodes += fs.bgs[i].FreeInodesCount() + totalFreeBlocks += uint64(fs.bgs[i].FreeBlocksCount()) + } + + if diff := cmp.Diff(gotSb, test.wantSb); diff != "" { + t.Errorf("superblock mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(gotBgs, test.wantBgs); diff != "" { + t.Errorf("block group descriptors mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(totalFreeInodes, fs.sb.FreeInodesCount()); diff != "" { + t.Errorf("total free inodes mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(totalFreeBlocks, fs.sb.FreeBlocksCount()); diff != "" { + t.Errorf("total free blocks mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/sentry/fs/ext/filesystem.go b/pkg/sentry/fs/ext/filesystem.go new file mode 100644 index 000000000..7150e75a5 --- /dev/null +++ b/pkg/sentry/fs/ext/filesystem.go @@ -0,0 +1,137 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "io" + "sync" + + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// filesystem implements vfs.FilesystemImpl. +type filesystem struct { + // TODO(b/134676337): Remove when all methods have been implemented. + vfs.FilesystemImpl + + vfsfs vfs.Filesystem + + // mu serializes changes to the Dentry tree and the usage of the read seeker. + mu sync.Mutex + + // dev is the ReadSeeker for the underlying fs device. It is protected by mu. + // + // The ext filesystems aim to maximize locality, i.e. place all the data + // blocks of a file close together. On a spinning disk, locality reduces the + // amount of movement of the head hence speeding up IO operations. On an SSD + // there are no moving parts but locality increases the size of each transer + // request. Hence, having mutual exclusion on the read seeker while reading a + // file *should* help in achieving the intended performance gains. + // + // Note: This synchronization was not coupled with the ReadSeeker itself + // because we want to synchronize across read/seek operations for the + // performance gains mentioned above. Helps enforcing one-file-at-a-time IO. + dev io.ReadSeeker + + // inodeCache maps absolute inode numbers to the corresponding Inode struct. + // Inodes should be removed from this once their reference count hits 0. + // + // Protected by mu because every addition and removal from this corresponds to + // a change in the dentry tree. + inodeCache map[uint32]*inode + + // sb represents the filesystem superblock. Immutable after initialization. + sb disklayout.SuperBlock + + // bgs represents all the block group descriptors for the filesystem. + // Immutable after initialization. + bgs []disklayout.BlockGroup +} + +// Compiles only if filesystem implements vfs.FilesystemImpl. +var _ vfs.FilesystemImpl = (*filesystem)(nil) + +// getOrCreateInode gets the inode corresponding to the inode number passed in. +// It creates a new one with the given inode number if one does not exist. +// +// Preconditions: must be holding fs.mu. +func (fs *filesystem) getOrCreateInode(inodeNum uint32) (*inode, error) { + if in, ok := fs.inodeCache[inodeNum]; ok { + return in, nil + } + + in, err := newInode(fs.dev, fs.sb, fs.bgs, inodeNum) + if err != nil { + return nil, err + } + + fs.inodeCache[inodeNum] = in + return in, nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() { +} + +// Sync implements vfs.FilesystemImpl.Sync. +func (fs *filesystem) Sync(ctx context.Context) error { + // This is a readonly filesystem for now. + return nil +} + +// The vfs.FilesystemImpl functions below return EROFS because their respective +// man pages say that EROFS must be returned if the path resolves to a file on +// a read-only filesystem. + +// TODO(b/134676337): Implement path traversal and return EROFS only if the +// path resolves to a Dentry within ext fs. + +// MkdirAt implements vfs.FilesystemImpl.MkdirAt. +func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { + return syserror.EROFS +} + +// MknodAt implements vfs.FilesystemImpl.MknodAt. +func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { + return syserror.EROFS +} + +// RenameAt implements vfs.FilesystemImpl.RenameAt. +func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { + return syserror.EROFS +} + +// RmdirAt implements vfs.FilesystemImpl.RmdirAt. +func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { + return syserror.EROFS +} + +// SetStatAt implements vfs.FilesystemImpl.SetStatAt. +func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { + return syserror.EROFS +} + +// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. +func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { + return syserror.EROFS +} + +// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. +func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { + return syserror.EROFS +} diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index e0602da17..246b97161 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -571,12 +571,6 @@ func overlayCheck(ctx context.Context, o *overlayEntry, p PermMask) error { if o.upper != nil { err = o.upper.check(ctx, p) } else { - if p.Write { - // Since writes will be redirected to the upper filesystem, the lower - // filesystem need not be writable, but must be readable for copy-up. - p.Write = false - p.Read = true - } err = o.lower.check(ctx, p) } o.copyMu.RUnlock() diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 70f5a3f0b..4c2d48e65 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -240,6 +240,9 @@ type InitKernelArgs struct { // RootAbstractSocketNamespace is the root Abstract Socket namespace. RootAbstractSocketNamespace *AbstractSocketNamespace + + // PIDNamespace is the root PID namespace. + PIDNamespace *PIDNamespace } // Init initialize the Kernel with no tasks. @@ -262,7 +265,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.featureSet = args.FeatureSet k.timekeeper = args.Timekeeper - k.tasks = newTaskSet() + k.tasks = newTaskSet(args.PIDNamespace) k.rootUserNamespace = args.RootUserNamespace k.rootUTSNamespace = args.RootUTSNamespace k.rootIPCNamespace = args.RootIPCNamespace diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index b21b182fc..8267929a6 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -90,9 +90,9 @@ type TaskSet struct { } // newTaskSet returns a new, empty TaskSet. -func newTaskSet() *TaskSet { - ts := &TaskSet{} - ts.Root = newPIDNamespace(ts, nil /* parent */, auth.NewRootUserNamespace()) +func newTaskSet(pidns *PIDNamespace) *TaskSet { + ts := &TaskSet{Root: pidns} + pidns.owner = ts return ts } @@ -186,6 +186,12 @@ func newPIDNamespace(ts *TaskSet, parent *PIDNamespace, userns *auth.UserNamespa } } +// NewRootPIDNamespace creates the root PID namespace. 'owner' is not available +// yet when root namespace is created and must be set by caller. +func NewRootPIDNamespace(userns *auth.UserNamespace) *PIDNamespace { + return newPIDNamespace(nil, nil, userns) +} + // NewChild returns a new, empty PID namespace that is a child of ns. Authority // over the new PID namespace is controlled by userns. func (ns *PIDNamespace) NewChild(userns *auth.UserNamespace) *PIDNamespace { diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go index 7eef19f74..8fe489c0e 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/epsocket/stack.go @@ -75,8 +75,8 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { addrs = append(addrs, inet.InterfaceAddr{ Family: family, - PrefixLen: uint8(len(a.Address) * 8), - Addr: []byte(a.Address), + PrefixLen: uint8(a.AddressWithPrefix.PrefixLen), + Addr: []byte(a.AddressWithPrefix.Address), // TODO(b/68878065): Other fields. }) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cb35635fc..60070874d 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -46,7 +46,6 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { nicid tcpip.NICID - addr tcpip.Address linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache } @@ -73,6 +72,10 @@ func (e *endpoint) ID() *stack.NetworkEndpointID { return &stack.NetworkEndpointID{ProtocolAddress} } +func (e *endpoint) PrefixLen() int { + return 0 +} + func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.ARPSize } @@ -122,19 +125,19 @@ type protocol struct { func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } func (p *protocol) MinimumPacketSize() int { return header.ARPSize } +func (p *protocol) DefaultPrefixLen() int { return 0 } func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.ARP(v) return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { - if addr != ProtocolAddress { +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } return &endpoint{ nicid: nicid, - addr: addr, linkEP: sender, linkAddrCache: linkAddrCache, }, nil diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 8ff428445..55e9eec99 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -29,16 +29,18 @@ import ( ) const ( - localIpv4Addr = "\x0a\x00\x00\x01" - remoteIpv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + localIpv4Addr = "\x0a\x00\x00\x01" + localIpv4PrefixLen = 24 + remoteIpv4Addr = "\x0a\x00\x00\x02" + ipv4SubnetAddr = "\x0a\x00\x00\x00" + ipv4SubnetMask = "\xff\xff\xff\x00" + ipv4Gateway = "\x0a\x00\x00\x03" + localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + localIpv6PrefixLen = 120 + remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" + ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" ) // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. @@ -197,7 +199,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, nil, &o) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -229,7 +231,7 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -295,7 +297,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -359,7 +361,7 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -426,7 +428,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, nil, &o) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -458,7 +460,7 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -532,7 +534,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e44a73d96..b7a06f525 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -49,16 +49,18 @@ const ( type endpoint struct { nicid tcpip.NICID id stack.NetworkEndpointID + prefixLen int linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ nicid: nicid, - id: stack.NetworkEndpointID{LocalAddress: addr}, + id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), @@ -93,6 +95,11 @@ func (e *endpoint) ID() *stack.NetworkEndpointID { return &e.id } +// PrefixLen returns the ipv4 endpoint subnet prefix length in bits. +func (e *endpoint) PrefixLen() int { + return e.prefixLen +} + // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { @@ -338,6 +345,11 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv4MinimumSize } +// DefaultPrefixLen returns the IPv4 default prefix length. +func (p *protocol) DefaultPrefixLen() int { + return header.IPv4AddressSize * 8 +} + // ParseAddresses implements NetworkProtocol.ParseAddresses. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv4(v) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d46d68e73..726362c87 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -99,7 +99,11 @@ func TestICMPCounts(t *testing.T) { }}, ) - ep, err := s.NetworkProtocolInstance(ProtocolNumber).NewEndpoint(0, lladdr1, &stubLinkAddressCache{}, &stubDispatcher{}, nil) + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil) if err != nil { t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e3e8739fd..331a8bdaa 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -46,6 +46,7 @@ const ( type endpoint struct { nicid tcpip.NICID id stack.NetworkEndpointID + prefixLen int linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher @@ -72,6 +73,11 @@ func (e *endpoint) ID() *stack.NetworkEndpointID { return &e.id } +// PrefixLen returns the ipv6 endpoint subnet prefix length in bits. +func (e *endpoint) PrefixLen() int { + return e.prefixLen +} + // Capabilities implements stack.NetworkEndpoint.Capabilities. func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.linkEP.Capabilities() @@ -172,6 +178,11 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } +// DefaultPrefixLen returns the IPv6 default prefix length. +func (p *protocol) DefaultPrefixLen() int { + return header.IPv6AddressSize * 8 +} + // ParseAddresses implements NetworkProtocol.ParseAddresses. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) @@ -179,10 +190,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ nicid: nicid, - id: stack.NetworkEndpointID{LocalAddress: addr}, + id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 30c0dee42..3e6ff4afb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -129,7 +129,7 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) { +func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { n.mu.RLock() defer n.mu.RUnlock() @@ -148,21 +148,16 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add } if r == nil { - return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress + return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress } - address := r.ep.ID().LocalAddress + addressWithPrefix := tcpip.AddressWithPrefix{ + Address: r.ep.ID().LocalAddress, + PrefixLen: r.ep.PrefixLen(), + } r.decRef() - // Find the least-constrained matching subnet for the address, if one - // exists, and return it. - var subnet tcpip.Subnet - for _, s := range n.subnets { - if s.Contains(address) && !subnet.Contains(s.ID()) { - subnet = s - } - } - return address, subnet, nil + return addressWithPrefix, nil } // primaryEndpoint returns the primary endpoint of n for the given network @@ -213,23 +208,26 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A n.mu.Lock() ref = n.endpoints[id] if ref == nil || !ref.tryIncRef() { - ref, _ = n.addAddressLocked(protocol, address, peb, true) - if ref != nil { - ref.holdsInsertRef = false + if netProto, ok := n.stack.networkProtocols[protocol]; ok { + addrWithPrefix := tcpip.AddressWithPrefix{address, netProto.DefaultPrefixLen()} + ref, _ = n.addAddressLocked(protocol, addrWithPrefix, peb, true) + if ref != nil { + ref.holdsInsertRef = false + } } } n.mu.Unlock() return ref } -func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { netProto, ok := n.stack.networkProtocols[protocol] if !ok { return nil, tcpip.ErrUnknownProtocol } // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP) + ep, err := netProto.NewEndpoint(n.id, addrWithPrefix, n.stack, n, n.linkEP) if err != nil { return nil, err } @@ -278,16 +276,10 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip. // AddAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { - return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint) -} - -// AddAddressWithOptions is the same as AddAddress, but allows you to specify -// whether the new endpoint can be primary or not. -func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { +func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocol, addr, peb, false) + _, err := n.addAddressLocked(protocol, addrWithPrefix, peb, false) n.mu.Unlock() return err @@ -298,10 +290,13 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) - for nid, ep := range n.endpoints { + for nid, ref := range n.endpoints { addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: ep.protocol, - Address: nid.LocalAddress, + Protocol: ref.protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: nid.LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }, }) } return addrs @@ -415,7 +410,12 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { - if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil { + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + return tcpip.ErrUnknownProtocol + } + addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()} + if _, err := n.addAddressLocked(protocol, addrWithPrefix, NeverPrimaryEndpoint, false); err != nil { return err } } @@ -572,7 +572,13 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r n.mu.Unlock() return ref } - ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true) + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + n.mu.Unlock() + return nil + } + addrWithPrefix := tcpip.AddressWithPrefix{dst, netProto.DefaultPrefixLen()} + ref, err := n.addAddressLocked(protocol, addrWithPrefix, CanBePrimaryEndpoint, true) n.mu.Unlock() if err == nil { ref.holdsInsertRef = false diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 462265281..2037eef9f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -181,6 +181,9 @@ type NetworkEndpoint interface { // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID + // PrefixLen returns the network endpoint's subnet prefix length in bits. + PrefixLen() int + // NICID returns the id of the NIC this endpoint belongs to. NICID() tcpip.NICID @@ -203,12 +206,15 @@ type NetworkProtocol interface { // than this targeted at this protocol. MinimumPacketSize() int + // DefaultPrefixLen returns the protocol's default prefix length. + DefaultPrefixLen() int + // ParsePorts returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3e8fb2a6c..57b8a9994 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -751,9 +751,26 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } +// AddAddressWithPrefix adds a new network-layer address/prefixLen to the +// specified NIC. +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix) *tcpip.Error { + return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, CanBePrimaryEndpoint) +} + // AddAddressWithOptions is the same as AddAddress, but allows you to specify // whether the new endpoint can be primary or not. func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { + netProto, ok := s.networkProtocols[protocol] + if !ok { + return tcpip.ErrUnknownProtocol + } + addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()} + return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, peb) +} + +// AddAddressWithPrefixAndOptions is the same as AddAddressWithPrefixLen, +// but allows you to specify whether the new endpoint can be primary or not. +func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -762,7 +779,7 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt return tcpip.ErrUnknownNICID } - return nic.AddAddressWithOptions(protocol, addr, peb) + return nic.AddAddress(protocol, addrWithPrefix, peb) } // AddSubnet adds a subnet range to the specified NIC. @@ -821,7 +838,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { // contains it) for the given NIC and protocol. Returns an arbitrary endpoint's // address if no primary addresses exist. Returns an error if the NIC doesn't // exist or has no endpoints. -func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) { +func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -829,7 +846,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.getMainNICAddress(protocol) } - return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID + return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 959071dbe..9d082bba4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "sort" "strings" "testing" @@ -32,8 +33,9 @@ import ( ) const ( - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fakeNetHeaderLen = 12 + fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fakeNetHeaderLen = 12 + fakeDefaultPrefixLen = 8 // fakeControlProtocol is used for control packets that represent // destination port unreachable. @@ -55,6 +57,7 @@ const ( type fakeNetworkEndpoint struct { nicid tcpip.NICID id stack.NetworkEndpointID + prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher linkEP stack.LinkEndpoint @@ -68,6 +71,10 @@ func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { return f.nicid } +func (f *fakeNetworkEndpoint) PrefixLen() int { + return f.prefixLen +} + func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } @@ -170,14 +177,19 @@ func (f *fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } +func (f *fakeNetworkProtocol) DefaultPrefixLen() int { + return fakeDefaultPrefixLen +} + func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, - id: stack.NetworkEndpointID{addr}, + id: stack.NetworkEndpointID{addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, linkEP: linkEP, @@ -212,15 +224,15 @@ func TestNetworkReceive(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -280,13 +292,13 @@ func TestNetworkReceive(t *testing.T) { func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatal("FindRoute failed:", err) } defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Errorf("WritePacket failed: %v", err) + t.Error("WritePacket failed:", err) } } @@ -297,13 +309,13 @@ func TestNetworkSend(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("NewNIC failed: %v", err) + t.Fatal("NewNIC failed:", err) } s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } // Make sure that the link-layer endpoint received the outbound packet. @@ -321,28 +333,28 @@ func TestNetworkSendMultiRoute(t *testing.T) { id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } id2, linkEP2 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } // Set a route table that sends all packets with odd destination @@ -371,7 +383,7 @@ func TestNetworkSendMultiRoute(t *testing.T) { func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatal("FindRoute failed:", err) } defer r.Release() @@ -388,7 +400,7 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err) + t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute) } } @@ -400,28 +412,28 @@ func TestRoutes(t *testing.T) { id1, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } id2, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } // Set a route table that sends all packets with odd destination @@ -464,11 +476,11 @@ func TestAddressRemoval(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -486,7 +498,7 @@ func TestAddressRemoval(t *testing.T) { // Remove the address, then check that packet doesn't get delivered // anymore. if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + t.Fatal("RemoveAddress failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -496,7 +508,7 @@ func TestAddressRemoval(t *testing.T) { // Check that removing the same address fails. if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress failed: %v", err) + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } } @@ -505,11 +517,11 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -531,7 +543,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { // Get a route, check that packet is still deliverable. r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatal("FindRoute failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -542,7 +554,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { // Remove the address, then check that packet is still deliverable // because the route is keeping the address alive. if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + t.Fatal("RemoveAddress failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -552,7 +564,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { // Check that removing the same address fails. if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress failed: %v", err) + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } // Release the route, then check that packet is not deliverable anymore. @@ -568,7 +580,7 @@ func TestPromiscuousMode(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -590,7 +602,7 @@ func TestPromiscuousMode(t *testing.T) { // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { - t.Fatalf("SetPromiscuousMode failed: %v", err) + t.Fatal("SetPromiscuousMode failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -601,13 +613,13 @@ func TestPromiscuousMode(t *testing.T) { // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err) + t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute) } // Set promiscuous mode to false, then check that packet can't be // delivered anymore. if err := s.SetPromiscuousMode(1, false); err != nil { - t.Fatalf("SetPromiscuousMode failed: %v", err) + t.Fatal("SetPromiscuousMode failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -624,11 +636,11 @@ func TestAddressSpoofing(t *testing.T) { id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -645,11 +657,11 @@ func TestAddressSpoofing(t *testing.T) { // With address spoofing enabled, FindRoute permits any address to be used // as the source. if err := s.SetSpoofing(1, true); err != nil { - t.Fatalf("SetSpoofing failed: %v", err) + t.Fatal("SetSpoofing failed:", err) } r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatal("FindRoute failed:", err) } if r.LocalAddress != srcAddr { t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) @@ -664,17 +676,17 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, header.IPv4Any, err) + t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { @@ -735,7 +747,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) @@ -791,7 +803,7 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -806,10 +818,10 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { fakeNet.packetCount[1] = 0 subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + t.Fatal("NewSubnet failed:", err) } if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + t.Fatal("AddSubnet failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) @@ -824,7 +836,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{ @@ -839,10 +851,10 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { fakeNet.packetCount[1] = 0 subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + t.Fatal("NewSubnet failed:", err) } if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + t.Fatal("AddSubnet failed:", err) } linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { @@ -894,38 +906,38 @@ func TestSubnetAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } addr := tcpip.Address("\x01\x01\x01\x01") mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) subnet, err := tcpip.NewSubnet(addr, mask) if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + t.Fatal("NewSubnet failed:", err) } if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) + t.Fatal("ContainsSubnet failed:", err) } else if contained { t.Fatal("got s.ContainsSubnet(...) = true, want = false") } if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + t.Fatal("AddSubnet failed:", err) } if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) + t.Fatal("ContainsSubnet failed:", err) } else if !contained { t.Fatal("got s.ContainsSubnet(...) = false, want = true") } if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatalf("RemoveSubnet failed: %v", err) + t.Fatal("RemoveSubnet failed:", err) } if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) + t.Fatal("ContainsSubnet failed:", err) } else if contained { t.Fatal("got s.ContainsSubnet(...) = true, want = false") } @@ -941,11 +953,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } // Insert <canBe> primary and <never> never-primary addresses. // Each one will add a network endpoint to the NIC. - primaryAddrAdded := make(map[tcpip.Address]tcpip.Subnet) + primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{}) for i := 0; i < canBe+never; i++ { var behavior stack.PrimaryEndpointBehavior if i < canBe { @@ -953,46 +965,39 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } else { behavior = stack.NeverPrimaryEndpoint } - // Add an address and in case of a primary one also add a - // subnet. + // Add an address and in case of a primary one include a + // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) - if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) - } if behavior == stack.CanBePrimaryEndpoint { - mask := tcpip.AddressMask(strings.Repeat("\xff", len(address))) - subnet, err := tcpip.NewSubnet(address, mask) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + addressWithPrefix := tcpip.AddressWithPrefix{address, addrLen * 8} + if err := s.AddAddressWithPrefixAndOptions(1, fakeNetNumber, addressWithPrefix, behavior); err != nil { + t.Fatal("AddAddressWithPrefixAndOptions failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + // Remember the address/prefix. + primaryAddrAdded[addressWithPrefix] = struct{}{} + } else { + if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) } - // Remember the address/subnet. - primaryAddrAdded[address] = subnet } } // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the - // address/subnet matches what we added. + // address/prefixLen matches what we added. if len(primaryAddrAdded) == 0 { // No primary addresses present, expect an error. - if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %v", err, tcpip.ErrNoLinkAddress) + if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress) } } else { // At least one primary address was added, expect a valid - // address and subnet. - gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber) + // address and prefixLen. + gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber) if err != nil { - t.Fatalf("GetMainNICAddress failed: %v", err) - } - expectedSubnet, ok := primaryAddrAdded[gotAddress] - if !ok { - t.Fatalf("GetMainNICAddress: got address = %v, wanted any in {%v}", gotAddress, primaryAddrAdded) + t.Fatal("GetMainNICAddress failed:", err) } - if gotSubnet != expectedSubnet { - t.Fatalf("GetMainNICAddress: got subnet = %v, wanted %v", gotSubnet, expectedSubnet) + if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok { + t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded) } } }) @@ -1007,65 +1012,194 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } for _, tc := range []struct { - name string - address tcpip.Address + name string + address tcpip.Address + prefixLen int }{ - {"IPv4", "\x01\x01\x01\x01"}, - {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, + {"IPv4", "\x01\x01\x01\x01", 24}, + {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, } { t.Run(tc.name, func(t *testing.T) { - address := tc.address - mask := tcpip.AddressMask(strings.Repeat("\xff", len(address))) - subnet, err := tcpip.NewSubnet(address, mask) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) + addressWithPrefix := tcpip.AddressWithPrefix{tc.address, tc.prefixLen} + + if err := s.AddAddressWithPrefix(1, fakeNetNumber, addressWithPrefix); err != nil { + t.Fatal("AddAddressWithPrefix failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress failed: %v", err) + // Check that we get the right initial address and prefix length. + if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } else if gotAddressWithPrefix != addressWithPrefix { + t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, addressWithPrefix) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) + if err := s.RemoveAddress(1, addressWithPrefix.Address); err != nil { + t.Fatal("RemoveAddress failed:", err) } - // Check that we get the right initial address and subnet. - if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress failed: %v", err) - } else if gotAddress != address { - t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address) - } else if gotSubnet != subnet { - t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet) + // Check that we get an error after removal. + if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress) } + }) + } +} + +// Simple network address generator. Good for 255 addresses. +type addressGenerator struct{ cnt byte } + +func (g *addressGenerator) next(addrLen int) tcpip.Address { + g.cnt++ + return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen)) +} + +func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { + if len(gotAddresses) != len(expectedAddresses) { + t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses)) + } + + sort.Slice(gotAddresses, func(i, j int) bool { + return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address + }) + sort.Slice(expectedAddresses, func(i, j int) bool { + return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address + }) + + for i, gotAddr := range gotAddresses { + expectedAddr := expectedAddresses[i] + if gotAddr != expectedAddr { + t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr) + } + } +} + +func TestAddAddress(t *testing.T) { + const nicid = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatalf("RemoveSubnet failed: %v", err) + var addrGen addressGenerator + expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) + for _, addrLen := range []int{4, 16} { + address := addrGen.next(addrLen) + if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + t.Fatalf("AddAddress(address=%s) failed: %s", address, err) + } + expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + }) + } + + gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + verifyAddresses(t, expectedAddresses, gotAddresses) +} + +func TestAddAddressWithPrefix(t *testing.T) { + const nicid = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + var addrGen addressGenerator + addrLenRange := []int{4, 16} + prefixLenRange := []int{8, 13, 20, 32} + expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) + for _, addrLen := range addrLenRange { + for _, prefixLen := range prefixLenRange { + address := addrGen.next(addrLen) + if err := s.AddAddressWithPrefix(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}); err != nil { + t.Errorf("AddAddressWithPrefix(address=%s, prefixLen=%d) failed: %s", address, prefixLen, err) } + expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen}, + }) + } + } - if err := s.RemoveAddress(1, address); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + verifyAddresses(t, expectedAddresses, gotAddresses) +} + +func TestAddAddressWithOptions(t *testing.T) { + const nicid = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + addrLenRange := []int{4, 16} + behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} + expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) + var addrGen addressGenerator + for _, addrLen := range addrLenRange { + for _, behavior := range behaviorRange { + address := addrGen.next(addrLen) + if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } + expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + }) + } + } - // Check that we get an error after removal. - if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress) + gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + verifyAddresses(t, expectedAddresses, gotAddresses) +} + +func TestAddAddressWithPrefixAndOptions(t *testing.T) { + const nicid = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + addrLenRange := []int{4, 16} + prefixLenRange := []int{8, 13, 20, 32} + behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} + expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) + var addrGen addressGenerator + for _, addrLen := range addrLenRange { + for _, prefixLen := range prefixLenRange { + for _, behavior := range behaviorRange { + address := addrGen.next(addrLen) + if err := s.AddAddressWithPrefixAndOptions(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}, behavior); err != nil { + t.Fatalf("AddAddressWithPrefixAndOptions(address=%s, prefixLen=%d, behavior=%d) failed: %s", address, prefixLen, behavior, err) + } + expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen}, + }) } - }) + } } + + gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + verifyAddresses(t, expectedAddresses, gotAddresses) } func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatal("CreateNIC failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatal("AddAddress failed:", err) } // Route all packets for address \x01 to NIC 1. s.SetRouteTable([]tcpip.Route{ @@ -1104,18 +1238,18 @@ func TestNICForwarding(t *testing.T) { id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) + t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) + t.Fatal("AddAddress #1 failed:", err) } id2, linkEP2 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) + t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) + t.Fatal("AddAddress #2 failed:", err) } // Route all packets to address 3 to NIC 2. diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c5d79da5e..4208c0303 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1059,14 +1059,23 @@ func ParseMACAddress(s string) (LinkAddress, error) { return LinkAddress(addr), nil } +// AddressWithPrefix is an address with its subnet prefix length. +type AddressWithPrefix struct { + // Address is a network address. + Address Address + + // PrefixLen is the subnet prefix length. + PrefixLen int +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { // Protocol is the protocol of the address. Protocol NetworkProtocolNumber - // Address is a network address. - Address Address + // AddressWithPrefix is a network address with its subnet prefix length. + AddressWithPrefix AddressWithPrefix } // danglingEndpointsMu protects access to danglingEndpoints. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 958d5712e..56c285f88 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -967,14 +967,14 @@ func TestTTL(t *testing.T) { multicast = false switch variant { case "v4", "mapped": - ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) if err != nil { t.Fatal(err) } wantTTL = ep.DefaultTTL() ep.Close() case "v6": - ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) if err != nil { t.Fatal(err) } |