diff options
Diffstat (limited to 'pkg')
117 files changed, 3464 insertions, 1906 deletions
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go index ed3881e27..50e22fe7e 100644 --- a/pkg/abi/linux/ptrace_amd64.go +++ b/pkg/abi/linux/ptrace_amd64.go @@ -50,3 +50,14 @@ type PtraceRegs struct { Fs uint64 Gs uint64 } + +// InstructionPointer returns the address of the next instruction to +// be executed. +func (p *PtraceRegs) InstructionPointer() uint64 { + return p.Rip +} + +// StackPointer returns the address of the Stack pointer. +func (p *PtraceRegs) StackPointer() uint64 { + return p.Rsp +} diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go index 6147738b3..da36811d2 100644 --- a/pkg/abi/linux/ptrace_arm64.go +++ b/pkg/abi/linux/ptrace_arm64.go @@ -27,3 +27,14 @@ type PtraceRegs struct { Pc uint64 Pstate uint64 } + +// InstructionPointer returns the address of the next instruction to be +// executed. +func (p *PtraceRegs) InstructionPointer() uint64 { + return p.Pc +} + +// StackPointer returns the address of the Stack pointer. +func (p *PtraceRegs) StackPointer() uint64 { + return p.Sp +} diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 23ef7ea8d..3ed6aba5c 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -18,7 +18,6 @@ go_library( deps = [ "//pkg/linewriter", "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/log/log.go b/pkg/log/log.go index d39af3bf4..073cf6238 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -40,7 +40,6 @@ import ( "sync/atomic" "time" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/linewriter" "gvisor.dev/gvisor/pkg/sync" ) @@ -105,7 +104,7 @@ func (l *Writer) Write(data []byte) (int, error) { n += w // Is it a non-blocking socket? - if pathErr, ok := err.(*os.PathError); ok && pathErr.Err == unix.EAGAIN { + if pathErr, ok := err.(*os.PathError); ok && pathErr.Timeout() { runtime.Gosched() continue } diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 8e0aa9019..58deb25fc 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -303,17 +303,18 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { // Take a reference on the upper Inode (transferred to // next.Inode.overlay.upper) and make new translations use it. - next.Inode.overlay.dataMu.Lock() + overlay := next.Inode.overlay + overlay.dataMu.Lock() childUpperInode.IncRef() - next.Inode.overlay.upper = childUpperInode - next.Inode.overlay.dataMu.Unlock() + overlay.upper = childUpperInode + overlay.dataMu.Unlock() // Invalidate existing translations through the lower Inode. - next.Inode.overlay.mappings.InvalidateAll(memmap.InvalidateOpts{}) + overlay.mappings.InvalidateAll(memmap.InvalidateOpts{}) // Remove existing memory mappings from the lower Inode. if lowerMappable != nil { - for seg := next.Inode.overlay.mappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + for seg := overlay.mappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { for m := range seg.Value() { lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable) } diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go index 141e3c27f..e2af1d2ae 100644 --- a/pkg/sentry/fs/gofer/inode_state.go +++ b/pkg/sentry/fs/gofer/inode_state.go @@ -109,6 +109,7 @@ func (i *inodeFileState) loadLoading(_ struct{}) { } // afterLoad is invoked by stateify. +// +checklocks:i.loading func (i *inodeFileState) afterLoad() { load := func() (err error) { // See comment on i.loading(). diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 52061175f..bbe282c03 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -17,6 +17,7 @@ package proc import ( "fmt" "io" + "math" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -26,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -498,6 +500,120 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) } +// portRangeInode implements fs.InodeOperations. It provides and allows +// modification of the range of ephemeral ports that IPv4 and IPv6 sockets +// choose from. +// +// +stateify savable +type portRangeInode struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + + // start and end store the port range. We must save/restore this here, + // since a netstack instance is created on restore. + start *uint16 + end *uint16 +} + +func newPortRangeInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ipf := &portRangeInode{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ipf, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. Truncate is called when +// O_TRUNC is specified for any kind of existing Dirent but is not called via +// (f)truncate for proc files. +func (*portRangeInode) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// +stateify savable +type portRangeFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + inode *portRangeInode +} + +// GetFile implements fs.InodeOperations.GetFile. +func (in *portRangeInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &portRangeFile{ + inode: in, + }), nil +} + +// Read implements fs.FileOperations.Read. +func (pf *portRangeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + if pf.inode.start == nil { + start, end := pf.inode.stack.PortRange() + pf.inode.start = &start + pf.inode.end = &end + } + + contents := fmt.Sprintf("%d %d\n", *pf.inode.start, *pf.inode.end) + n, err := dst.CopyOut(ctx, []byte(contents)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +// +// Offset is ignored, multiple writes are not supported. +func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Only consider size of one memory page for input for performance + // reasons. + src = src.TakeFirst(usermem.PageSize - 1) + + ports := make([]int32, 2) + n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts) + if err != nil { + return 0, err + } + + // Port numbers must be uint16s. + if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 { + return 0, syserror.EINVAL + } + + if err := pf.inode.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil { + return 0, err + } + if pf.inode.start == nil { + pf.inode.start = new(uint16) + pf.inode.end = new(uint16) + } + *pf.inode.start = uint16(ports[0]) + *pf.inode.end = uint16(ports[1]) + return n, nil +} + func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { contents := map[string]*fs.Inode{ // Add tcp_sack. @@ -506,12 +622,15 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine // Add ip_forward. "ip_forward": newIPForwardingInode(ctx, msrc, s), + // Allow for configurable ephemeral port ranges. Note that this + // controls ports for both IPv4 and IPv6 sockets. + "ip_local_port_range": newPortRangeInode(ctx, msrc, s), + // The following files are simple stubs until they are // implemented in netstack, most of these files are // configuration related. We use the value closest to the // actual netstack behavior or any empty file, all of these // files will have mode 0444 (read-only for all users). - "ip_local_port_range": newStaticProcInode(ctx, msrc, []byte("16000 65535")), "ip_local_reserved_ports": newStaticProcInode(ctx, msrc, []byte("")), "ipfrag_time": newStaticProcInode(ctx, msrc, []byte("30")), "ip_nonlocal_bind": newStaticProcInode(ctx, msrc, []byte("0")), diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index d8c237753..e75954105 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -137,6 +137,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.Release(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // rootInode is the root directory inode for the devpts mounts. // // +stateify savable diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 917f1873d..d4fc484a2 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -548,3 +548,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.mu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 204d8d143..fef857afb 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -47,19 +47,14 @@ type FilesystemType struct{} // +stateify savable type filesystemOptions struct { - // userID specifies the numeric uid of the mount owner. - // This option should not be specified by the filesystem owner. - // It is set by libfuse (or, if libfuse is not used, must be set - // by the filesystem itself). For more information, see man page - // for fuse(8) - userID uint32 - - // groupID specifies the numeric gid of the mount owner. - // This option should not be specified by the filesystem owner. - // It is set by libfuse (or, if libfuse is not used, must be set - // by the filesystem itself). For more information, see man page - // for fuse(8) - groupID uint32 + // mopts contains the raw, unparsed mount options passed to this filesystem. + mopts string + + // uid of the mount owner. + uid auth.KUID + + // gid of the mount owner. + gid auth.KGID // rootMode specifies the the file mode of the filesystem's root. rootMode linux.FileMode @@ -73,6 +68,19 @@ type filesystemOptions struct { // specified as "max_read" in fs parameters. // If not specified by user, use math.MaxUint32 as default value. maxRead uint32 + + // defaultPermissions is the default_permissions mount option. It instructs + // the kernel to perform a standard unix permission checks based on + // ownership and mode bits, instead of deferring the check to the server. + // + // Immutable after mount. + defaultPermissions bool + + // allowOther is the allow_other mount option. It allows processes that + // don't own the FUSE mount to call into it. + // + // Immutable after mount. + allowOther bool } // filesystem implements vfs.FilesystemImpl. @@ -108,18 +116,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } - var fsopts filesystemOptions + fsopts := filesystemOptions{mopts: opts.Data} mopts := vfs.GenericParseMountOptions(opts.Data) deviceDescriptorStr, ok := mopts["fd"] if !ok { - log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name()) + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option fd missing") return nil, nil, syserror.EINVAL } delete(mopts, "fd") deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) if err != nil { - log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err) + ctx.Debugf("fusefs.FilesystemType.GetFilesystem: invalid fd: %q (%v)", deviceDescriptorStr, err) return nil, nil, syserror.EINVAL } @@ -141,38 +149,54 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Parse and set all the other supported FUSE mount options. // TODO(gVisor.dev/issue/3229): Expand the supported mount options. - if userIDStr, ok := mopts["user_id"]; ok { + if uidStr, ok := mopts["user_id"]; ok { delete(mopts, "user_id") - userID, err := strconv.ParseUint(userIDStr, 10, 32) + uid, err := strconv.ParseUint(uidStr, 10, 32) if err != nil { - log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr) + log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), uidStr) return nil, nil, syserror.EINVAL } - fsopts.userID = uint32(userID) + kuid := creds.UserNamespace.MapToKUID(auth.UID(uid)) + if !kuid.Ok() { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped uid: %d", uid) + return nil, nil, syserror.EINVAL + } + fsopts.uid = kuid + } else { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option user_id missing") + return nil, nil, syserror.EINVAL } - if groupIDStr, ok := mopts["group_id"]; ok { + if gidStr, ok := mopts["group_id"]; ok { delete(mopts, "group_id") - groupID, err := strconv.ParseUint(groupIDStr, 10, 32) + gid, err := strconv.ParseUint(gidStr, 10, 32) if err != nil { - log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr) + log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), gidStr) + return nil, nil, syserror.EINVAL + } + kgid := creds.UserNamespace.MapToKGID(auth.GID(gid)) + if !kgid.Ok() { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped gid: %d", gid) return nil, nil, syserror.EINVAL } - fsopts.groupID = uint32(groupID) + fsopts.gid = kgid + } else { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option group_id missing") + return nil, nil, syserror.EINVAL } - rootMode := linux.FileMode(0777) - modeStr, ok := mopts["rootmode"] - if ok { + if modeStr, ok := mopts["rootmode"]; ok { delete(mopts, "rootmode") mode, err := strconv.ParseUint(modeStr, 8, 32) if err != nil { log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr) return nil, nil, syserror.EINVAL } - rootMode = linux.FileMode(mode) + fsopts.rootMode = linux.FileMode(mode) + } else { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option rootmode missing") + return nil, nil, syserror.EINVAL } - fsopts.rootMode = rootMode // Set the maxInFlightRequests option. fsopts.maxActiveRequests = maxActiveRequestsDefault @@ -192,6 +216,16 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.maxRead = math.MaxUint32 } + if _, ok := mopts["default_permissions"]; ok { + delete(mopts, "default_permissions") + fsopts.defaultPermissions = true + } + + if _, ok := mopts["allow_other"]; ok { + delete(mopts, "allow_other") + fsopts.allowOther = true + } + // Check for unparsed options. if len(mopts) != 0 { log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts) @@ -260,6 +294,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.Release(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return fs.opts.mopts +} + // inode implements kernfs.Inode. // // +stateify savable @@ -318,6 +357,37 @@ func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FU return i } +// CheckPermissions implements kernfs.Inode.CheckPermissions. +func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { + // Since FUSE operations are ultimately backed by a userspace process (the + // fuse daemon), allowing a process to call into fusefs grants the daemon + // ptrace-like capabilities over the calling process. Because of this, by + // default FUSE only allows the mount owner to interact with the + // filesystem. This explicitly excludes setuid/setgid processes. + // + // This behaviour can be overriden with the 'allow_other' mount option. + // + // See fs/fuse/dir.c:fuse_allow_current_process() in Linux. + if !i.fs.opts.allowOther { + if creds.RealKUID != i.fs.opts.uid || + creds.EffectiveKUID != i.fs.opts.uid || + creds.SavedKUID != i.fs.opts.uid || + creds.RealKGID != i.fs.opts.gid || + creds.EffectiveKGID != i.fs.opts.gid || + creds.SavedKGID != i.fs.opts.gid { + return syserror.EACCES + } + } + + // By default, fusefs delegates all permission checks to the server. + // However, standard unix permission checks can be enabled with the + // default_permissions mount option. + if i.fs.opts.defaultPermissions { + return i.InodeAttrs.CheckPermissions(ctx, creds, ats) + } + return nil +} + // Open implements kernfs.Inode.Open. func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { isDir := i.InodeAttrs.Mode().IsDir() diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 8f95473b6..c34451269 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -15,7 +15,9 @@ package gofer import ( + "fmt" "math" + "strings" "sync" "sync/atomic" @@ -1608,3 +1610,58 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } + +type mopt struct { + key string + value interface{} +} + +func (m mopt) String() string { + if m.value == nil { + return fmt.Sprintf("%s", m.key) + } + return fmt.Sprintf("%s=%v", m.key, m.value) +} + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + optsKV := []mopt{ + {moptTransport, transportModeFD}, // Only valid value, currently. + {moptReadFD, fs.opts.fd}, // Currently, read and write FD are the same. + {moptWriteFD, fs.opts.fd}, // Currently, read and write FD are the same. + {moptAname, fs.opts.aname}, + {moptDfltUID, fs.opts.dfltuid}, + {moptDfltGID, fs.opts.dfltgid}, + {moptMsize, fs.opts.msize}, + {moptVersion, fs.opts.version}, + {moptDentryCacheLimit, fs.opts.maxCachedDentries}, + } + + switch fs.opts.interop { + case InteropModeExclusive: + optsKV = append(optsKV, mopt{moptCache, cacheFSCache}) + case InteropModeWritethrough: + optsKV = append(optsKV, mopt{moptCache, cacheFSCacheWritethrough}) + case InteropModeShared: + if fs.opts.regularFilesUseSpecialFileFD { + optsKV = append(optsKV, mopt{moptCache, cacheNone}) + } else { + optsKV = append(optsKV, mopt{moptCache, cacheRemoteRevalidating}) + } + } + if fs.opts.forcePageCache { + optsKV = append(optsKV, mopt{moptForcePageCache, nil}) + } + if fs.opts.limitHostFDTranslation { + optsKV = append(optsKV, mopt{moptLimitHostFDTranslation, nil}) + } + if fs.opts.overlayfsStaleRead { + optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil}) + } + + opts := make([]string, 0, len(optsKV)) + for _, opt := range optsKV { + opts = append(opts, opt.String()) + } + return strings.Join(opts, ",") +} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 1508cbdf1..71569dc65 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -66,6 +66,34 @@ import ( // Name is the default filesystem name. const Name = "9p" +// Mount option names for goferfs. +const ( + moptTransport = "trans" + moptReadFD = "rfdno" + moptWriteFD = "wfdno" + moptAname = "aname" + moptDfltUID = "dfltuid" + moptDfltGID = "dfltgid" + moptMsize = "msize" + moptVersion = "version" + moptDentryCacheLimit = "dentry_cache_limit" + moptCache = "cache" + moptForcePageCache = "force_page_cache" + moptLimitHostFDTranslation = "limit_host_fd_translation" + moptOverlayfsStaleRead = "overlayfs_stale_read" +) + +// Valid values for the "cache" mount option. +const ( + cacheNone = "none" + cacheFSCache = "fscache" + cacheFSCacheWritethrough = "fscache_writethrough" + cacheRemoteRevalidating = "remote_revalidating" +) + +// Valid values for "trans" mount option. +const transportModeFD = "fd" + // FilesystemType implements vfs.FilesystemType. // // +stateify savable @@ -301,39 +329,39 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Get the attach name. fsopts.aname = "/" - if aname, ok := mopts["aname"]; ok { - delete(mopts, "aname") + if aname, ok := mopts[moptAname]; ok { + delete(mopts, moptAname) fsopts.aname = aname } // Parse the cache policy. For historical reasons, this defaults to the // least generally-applicable option, InteropModeExclusive. fsopts.interop = InteropModeExclusive - if cache, ok := mopts["cache"]; ok { - delete(mopts, "cache") + if cache, ok := mopts[moptCache]; ok { + delete(mopts, moptCache) switch cache { - case "fscache": + case cacheFSCache: fsopts.interop = InteropModeExclusive - case "fscache_writethrough": + case cacheFSCacheWritethrough: fsopts.interop = InteropModeWritethrough - case "none": + case cacheNone: fsopts.regularFilesUseSpecialFileFD = true fallthrough - case "remote_revalidating": + case cacheRemoteRevalidating: fsopts.interop = InteropModeShared default: - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: cache=%s", cache) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: %s=%s", moptCache, cache) return nil, nil, syserror.EINVAL } } // Parse the default UID and GID. fsopts.dfltuid = _V9FS_DEFUID - if dfltuidstr, ok := mopts["dfltuid"]; ok { - delete(mopts, "dfltuid") + if dfltuidstr, ok := mopts[moptDfltUID]; ok { + delete(mopts, moptDfltUID) dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltUID, dfltuidstr) return nil, nil, syserror.EINVAL } // In Linux, dfltuid is interpreted as a UID and is converted to a KUID @@ -342,11 +370,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.dfltuid = auth.KUID(dfltuid) } fsopts.dfltgid = _V9FS_DEFGID - if dfltgidstr, ok := mopts["dfltgid"]; ok { - delete(mopts, "dfltgid") + if dfltgidstr, ok := mopts[moptDfltGID]; ok { + delete(mopts, moptDfltGID) dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltGID, dfltgidstr) return nil, nil, syserror.EINVAL } fsopts.dfltgid = auth.KGID(dfltgid) @@ -354,11 +382,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Parse the 9P message size. fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M - if msizestr, ok := mopts["msize"]; ok { - delete(mopts, "msize") + if msizestr, ok := mopts[moptMsize]; ok { + delete(mopts, moptMsize) msize, err := strconv.ParseUint(msizestr, 10, 32) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: msize=%s", msizestr) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: %s=%s", moptMsize, msizestr) return nil, nil, syserror.EINVAL } fsopts.msize = uint32(msize) @@ -366,34 +394,34 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Parse the 9P protocol version. fsopts.version = p9.HighestVersionString() - if version, ok := mopts["version"]; ok { - delete(mopts, "version") + if version, ok := mopts[moptVersion]; ok { + delete(mopts, moptVersion) fsopts.version = version } // Parse the dentry cache limit. fsopts.maxCachedDentries = 1000 - if str, ok := mopts["dentry_cache_limit"]; ok { - delete(mopts, "dentry_cache_limit") + if str, ok := mopts[moptDentryCacheLimit]; ok { + delete(mopts, moptDentryCacheLimit) maxCachedDentries, err := strconv.ParseUint(str, 10, 64) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str) return nil, nil, syserror.EINVAL } fsopts.maxCachedDentries = maxCachedDentries } // Handle simple flags. - if _, ok := mopts["force_page_cache"]; ok { - delete(mopts, "force_page_cache") + if _, ok := mopts[moptForcePageCache]; ok { + delete(mopts, moptForcePageCache) fsopts.forcePageCache = true } - if _, ok := mopts["limit_host_fd_translation"]; ok { - delete(mopts, "limit_host_fd_translation") + if _, ok := mopts[moptLimitHostFDTranslation]; ok { + delete(mopts, moptLimitHostFDTranslation) fsopts.limitHostFDTranslation = true } - if _, ok := mopts["overlayfs_stale_read"]; ok { - delete(mopts, "overlayfs_stale_read") + if _, ok := mopts[moptOverlayfsStaleRead]; ok { + delete(mopts, moptOverlayfsStaleRead) fsopts.overlayfsStaleRead = true } // fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying @@ -469,34 +497,34 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { // Check that the transport is "fd". - trans, ok := mopts["trans"] - if !ok || trans != "fd" { - ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'") + trans, ok := mopts[moptTransport] + if !ok || trans != transportModeFD { + ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as '%s=%s'", moptTransport, transportModeFD) return -1, syserror.EINVAL } - delete(mopts, "trans") + delete(mopts, moptTransport) // Check that read and write FDs are provided and identical. - rfdstr, ok := mopts["rfdno"] + rfdstr, ok := mopts[moptReadFD] if !ok { - ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'") + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as '%s=<file descriptor>'", moptReadFD) return -1, syserror.EINVAL } - delete(mopts, "rfdno") + delete(mopts, moptReadFD) rfd, err := strconv.Atoi(rfdstr) if err != nil { - ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr) + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: %s=%s", moptReadFD, rfdstr) return -1, syserror.EINVAL } - wfdstr, ok := mopts["wfdno"] + wfdstr, ok := mopts[moptWriteFD] if !ok { - ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'") + ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as '%s=<file descriptor>'", moptWriteFD) return -1, syserror.EINVAL } - delete(mopts, "wfdno") + delete(mopts, moptWriteFD) wfd, err := strconv.Atoi(wfdstr) if err != nil { - ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr) + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: %s=%s", moptWriteFD, wfdstr) return -1, syserror.EINVAL } if rfd != wfd { diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index ad5de80dc..b9cce4181 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -260,6 +260,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s unix.Stat_t diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index e63588e33..1cd3137e6 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -67,6 +67,11 @@ type filesystem struct { kernfs.Filesystem } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + type file struct { kernfs.DynamicBytesFile content string diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index f7f795b10..84e37f793 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -85,6 +85,8 @@ func putDentrySlice(ds *[]*dentry) { // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this. +// +// +checklocks:fs.renameMu func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) { fs.renameMu.RUnlock() if *dsp == nil { @@ -110,6 +112,7 @@ func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]* putDentrySlice(*dsp) } +// +checklocks:fs.renameMu func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { if *ds == nil { fs.renameMu.Unlock() @@ -1761,3 +1764,15 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + // Return the mount options from the topmost layer. + var vd vfs.VirtualDentry + if fs.opts.UpperRoot.Ok() { + vd = fs.opts.UpperRoot + } else { + vd = fs.opts.LowerRoots[0] + } + return vd.Mount().Filesystem().Impl().MountOptions() +} diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 429733c10..3f05e444e 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -80,6 +80,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // inode implements kernfs.Inode. // // +stateify savable diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 8716d0a3c..254a8b062 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -104,6 +104,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.Release(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return fmt.Sprintf("dentry_cache_limit=%d", fs.MaxCachedDentries) +} + // dynamicInode is an overfitted interface for common Inodes with // dynamicByteSource types used in procfs. // diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index fd7823daa..fb274b78e 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -17,6 +17,7 @@ package proc import ( "bytes" "fmt" + "math" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -69,17 +70,17 @@ func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]kernfs.Inode{ "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ - "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), - "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), - "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), - "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), - "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), + "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), + "ip_local_port_range": fs.newInode(ctx, root, 0644, &portRange{stack: stack}), + "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")), "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")), "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")), "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")), @@ -421,3 +422,68 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs } return n, nil } + +// portRange implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/ip_local_port_range. +// +// +stateify savable +type portRange struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` + + // start and end store the port range. We must save/restore this here, + // since a netstack instance is created on restore. + start *uint16 + end *uint16 +} + +var _ vfs.WritableDynamicBytesSource = (*portRange)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error { + if pr.start == nil { + start, end := pr.stack.PortRange() + pr.start = &start + pr.end = &end + } + _, err := fmt.Fprintf(buf, "%d %d\n", *pr.start, *pr.end) + return err +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is + // large. + src = src.TakeFirst(usermem.PageSize - 1) + + ports := make([]int32, 2) + n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts) + if err != nil { + return 0, err + } + + // Port numbers must be uint16s. + if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 { + return 0, syserror.EINVAL + } + + if err := pr.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil { + return 0, err + } + if pr.start == nil { + pr.start = new(uint16) + pr.end = new(uint16) + } + *pr.start = uint16(ports[0]) + *pr.end = uint16(ports[1]) + return n, nil +} diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index fda1fa942..735756280 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -85,6 +85,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // inode implements kernfs.Inode. // // +stateify savable diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index dbd9ebdda..1d9280dae 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -143,6 +143,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.Release(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return fmt.Sprintf("dentry_cache_limit=%d", fs.MaxCachedDentries) +} + // dir implements kernfs.Inode. // // +stateify savable diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 4f675c21e..5fdca1d46 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -898,3 +898,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe d = d.parent } } + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return fs.mopts +} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index a01e413e0..8df81f589 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -70,6 +70,10 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 + // mopts contains the tmpfs-specific mount options passed to this + // filesystem. Immutable. + mopts string + // mu serializes changes to the Dentry tree. mu sync.RWMutex `state:"nosave"` @@ -184,6 +188,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mfp: mfp, clock: clock, devMinor: devMinor, + mopts: opts.Data, } fs.vfsfs.Init(vfsObj, newFSType, &fs) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 9057d2b4e..6cb1a23e0 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -590,6 +590,23 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, return nil, err } + // Clear the Merkle tree file if they are to be generated at runtime. + // TODO(b/182315468): Optimize the Merkle tree generate process to + // allow only updating certain files/directories. + if fs.allowRuntimeEnable { + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: childMerkleVD, + Start: childMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_TRUNC, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + } + // The dentry needs to be cleaned up if any error occurs. IncRef will be // called if a verity child dentry is successfully created. defer childMerkleVD.DecRef(ctx) diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 374f71568..0d9b0ee2c 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -38,6 +38,7 @@ import ( "fmt" "math" "strconv" + "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -310,6 +311,24 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt d.DecRef(ctx) return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") } + + // Clear the Merkle tree file if they are to be generated at runtime. + // TODO(b/182315468): Optimize the Merkle tree generate process to + // allow only updating certain files/directories. + if fs.allowRuntimeEnable { + lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_TRUNC, + Mode: 0644, + }) + if err != nil { + return nil, nil, err + } + lowerMerkleFD.DecRef(ctx) + } + d.lowerMerkleVD = lowerMerkleVD // Get metadata from the underlying file system. @@ -418,6 +437,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.lowerMount.DecRef(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -750,6 +774,50 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { + if !fd.d.isDir() { + return syserror.ENOTDIR + } + fd.mu.Lock() + defer fd.mu.Unlock() + + var ds []vfs.Dirent + err := fd.lowerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { + // Do not include the Merkle tree files. + if strings.Contains(dirent.Name, merklePrefix) || strings.Contains(dirent.Name, merkleRootPrefix) { + return nil + } + if fd.d.verityEnabled() { + // Verify that the child is expected. + if dirent.Name != "." && dirent.Name != ".." { + if _, ok := fd.d.childrenNames[dirent.Name]; !ok { + return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name)) + } + } + } + ds = append(ds, dirent) + return nil + })) + + if err != nil { + return err + } + + // The result should contain all children plus "." and "..". + if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 { + return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) + } + + for fd.off < int64(len(ds)) { + if err := cb.Handle(ds[fd.off]); err != nil { + return err + } + fd.off++ + } + return nil +} + // Seek implements vfs.FileDescriptionImpl.Seek. func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { fd.mu.Lock() diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index f31277d30..6b71bd3a9 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -93,6 +93,14 @@ type Stack interface { // SetForwarding enables or disables packet forwarding between NICs. SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error + + // PortRange returns the UDP and TCP inclusive range of ephemeral ports + // used in both IPv4 and IPv6. + PortRange() (uint16, uint16) + + // SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range + // (inclusive). + SetPortRange(start uint16, end uint16) error } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 9ebeba8a3..03e2608c2 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -164,3 +164,15 @@ func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable b s.IPForwarding = enable return nil } + +// PortRange implements inet.Stack.PortRange. +func (*TestStack) PortRange() (uint16, uint16) { + // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). + return 32768, 28232 +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (*TestStack) SetPortRange(start uint16, end uint16) error { + // No-op. + return nil +} diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 58cc11a13..a4af3e21b 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -876,6 +876,7 @@ func (f *MemoryFile) UpdateUsage() error { // in bs, sets committed[i] to 1 if the page is committed and 0 otherwise. // // Precondition: f.mu must be held; it may be unlocked and reacquired. +// +checklocks:f.mu func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(bs []byte, committed []byte) error) error { // Track if anything changed to elide the merge. In the common case, we // expect all segments to be committed and no merge to occur. @@ -925,72 +926,73 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( r := seg.Range() var checkErr error - err := f.forEachMappingSlice(r, func(s []byte) { - if checkErr != nil { - return - } - - // Ensure that we have sufficient buffer for the call - // (one byte per page). The length of each slice must - // be page-aligned. - bufLen := len(s) / usermem.PageSize - if len(buf) < bufLen { - buf = make([]byte, bufLen) - } + err := f.forEachMappingSlice(r, + func(s []byte) { + if checkErr != nil { + return + } - // Query for new pages in core. - // NOTE(b/165896008): mincore (which is passed as checkCommitted) - // by f.UpdateUsage() might take a really long time. So unlock f.mu - // while checkCommitted runs. - f.mu.Unlock() - err := checkCommitted(s, buf) - f.mu.Lock() - if err != nil { - checkErr = err - return - } + // Ensure that we have sufficient buffer for the call + // (one byte per page). The length of each slice must + // be page-aligned. + bufLen := len(s) / usermem.PageSize + if len(buf) < bufLen { + buf = make([]byte, bufLen) + } - // Scan each page and switch out segments. - seg := f.usage.LowerBoundSegment(r.Start) - for i := 0; i < bufLen; { - if buf[i]&0x1 == 0 { - i++ - continue + // Query for new pages in core. + // NOTE(b/165896008): mincore (which is passed as checkCommitted) + // by f.UpdateUsage() might take a really long time. So unlock f.mu + // while checkCommitted runs. + f.mu.Unlock() + err := checkCommitted(s, buf) + f.mu.Lock() + if err != nil { + checkErr = err + return } - // Scan to the end of this committed range. - j := i + 1 - for ; j < bufLen; j++ { - if buf[j]&0x1 == 0 { - break + + // Scan each page and switch out segments. + seg := f.usage.LowerBoundSegment(r.Start) + for i := 0; i < bufLen; { + if buf[i]&0x1 == 0 { + i++ + continue } - } - committedFR := memmap.FileRange{ - Start: r.Start + uint64(i*usermem.PageSize), - End: r.Start + uint64(j*usermem.PageSize), - } - // Advance seg to committedFR.Start. - for seg.Ok() && seg.End() < committedFR.Start { - seg = seg.NextSegment() - } - // Mark pages overlapping committedFR as committed. - for seg.Ok() && seg.Start() < committedFR.End { - if seg.ValuePtr().canCommit() { - seg = f.usage.Isolate(seg, committedFR) - seg.ValuePtr().knownCommitted = true - amount := seg.Range().Length() - usage.MemoryAccounting.Inc(amount, seg.ValuePtr().kind) - f.usageExpected += amount - changedAny = true + // Scan to the end of this committed range. + j := i + 1 + for ; j < bufLen; j++ { + if buf[j]&0x1 == 0 { + break + } } - seg = seg.NextSegment() + committedFR := memmap.FileRange{ + Start: r.Start + uint64(i*usermem.PageSize), + End: r.Start + uint64(j*usermem.PageSize), + } + // Advance seg to committedFR.Start. + for seg.Ok() && seg.End() < committedFR.Start { + seg = seg.NextSegment() + } + // Mark pages overlapping committedFR as committed. + for seg.Ok() && seg.Start() < committedFR.End { + if seg.ValuePtr().canCommit() { + seg = f.usage.Isolate(seg, committedFR) + seg.ValuePtr().knownCommitted = true + amount := seg.Range().Length() + usage.MemoryAccounting.Inc(amount, seg.ValuePtr().kind) + f.usageExpected += amount + changedAny = true + } + seg = seg.NextSegment() + } + // Continue scanning for committed pages. + i = j + 1 } - // Continue scanning for committed pages. - i = j + 1 - } - // Advance r.Start. - r.Start += uint64(len(s)) - }) + // Advance r.Start. + r.Start += uint64(len(s)) + }) if checkErr != nil { return checkErr } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index e6323244c..5bcf92e14 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -504,3 +504,14 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } + +// PortRange implements inet.Stack.PortRange. +func (*Stack) PortRange() (uint16, uint16) { + // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). + return 32768, 28232 +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (*Stack) SetPortRange(start uint16, end uint16) error { + return syserror.EACCES +} diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f2dc7c90b..9efb195f0 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -83,110 +83,121 @@ var Metrics = tcpip.Stats{ V4: tcpip.ICMPv4Stats{ PacketsSent: tcpip.ICMPv4SentPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent by netstack."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped by netstack due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped by netstack due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received by netstack."), }, - Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), + Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Number of ICMPv4 packets received that the transport layer could not parse."), }, }, V6: tcpip.ICMPv6Stats{ PacketsSent: tcpip.ICMPv6SentPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent by netstack."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent by netstack."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped by netstack due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped by netstack due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received by netstack."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received by netstack."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), }, - Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), + Unrecognized: mustCreateMetric("/netstack/icmp/v6/packets_received/unrecognized", "Number of ICMPv6 packets received that the transport layer does not know how to parse."), + Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Number of ICMPv6 packets received that the transport layer could not parse."), + RouterOnlyPacketsDroppedByHost: mustCreateMetric("/netstack/icmp/v6/packets_received/router_only_packets_dropped_by_host", "Number of ICMPv6 packets dropped due to being router-specific packets."), }, }, }, IGMP: tcpip.IGMPStats{ PacketsSent: tcpip.IGMPSentPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent by netstack."), }, - Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped by netstack due to link layer errors."), }, PacketsReceived: tcpip.IGMPReceivedPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received by netstack."), }, - Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."), - ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."), - Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received by netstack that could not be parsed."), + ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Number of received IGMP packets with bad checksums."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received by netstack."), }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), - MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), - MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), - IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Total number of IP packets dropped in the Prerouting chain."), - IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."), - IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."), - OptionTimestampReceived: mustCreateMetric("/netstack/ip/options/timestamp_received", "Total number of timestamp options found in received IP packets."), - OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Total number of record route options found in received IP packets."), - OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Total number of router alert options found in received IP packets."), - OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Total number of unknown options found in received IP packets."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + DisabledPacketsReceived: mustCreateMetric("/netstack/ip/disabled_packets_received", "Number of IP packets received from the link layer when the IP layer is disabled."), + InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Number of IP packets received with an unknown or invalid destination address."), + InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Number of IP packets received with an unknown or invalid source address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Number of IP fragments which failed IP fragment validation checks."), + IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Number of IP packets dropped in the Prerouting chain."), + IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Number of IP packets dropped in the Input chain."), + IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Number of IP packets dropped in the Output chain."), + OptionTimestampReceived: mustCreateMetric("/netstack/ip/options/timestamp_received", "Number of timestamp options found in received IP packets."), + OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Number of record route options found in received IP packets."), + OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Number of router alert options found in received IP packets."), + OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Number of unknown options found in received IP packets."), }, ARP: tcpip.ARPStats{ PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."), diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index cc0fadeb5..b215067cf 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -336,7 +336,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.ParamProblem.Value(), // InParmProbs. in.SrcQuench.Value(), // InSrcQuenchs. in.Redirect.Value(), // InRedirects. - in.Echo.Value(), // InEchos. + in.EchoRequest.Value(), // InEchos. in.EchoReply.Value(), // InEchoReps. in.Timestamp.Value(), // InTimestamps. in.TimestampReply.Value(), // InTimestampReps. @@ -349,7 +349,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { out.ParamProblem.Value(), // OutParmProbs. out.SrcQuench.Value(), // OutSrcQuenchs. out.Redirect.Value(), // OutRedirects. - out.Echo.Value(), // OutEchos. + out.EchoRequest.Value(), // OutEchos. out.EchoReply.Value(), // OutEchoReps. out.Timestamp.Value(), // OutTimestamps. out.TimestampReply.Value(), // OutTimestampReps. @@ -478,3 +478,13 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) } return nil } + +// PortRange implements inet.Stack.PortRange. +func (s *Stack) PortRange() (uint16, uint16) { + return s.Stack.PortRange() +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (s *Stack) SetPortRange(start uint16, end uint16) error { + return syserr.TranslateNetstackError(s.Stack.SetPortRange(start, end)).ToError() +} diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 7ad0eaf86..3caf417ca 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -291,6 +291,11 @@ func (fs *anonFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDe return PrependPathSyntheticError{} } +// MountOptions implements FilesystemImpl.MountOptions. +func (fs *anonFilesystem) MountOptions() string { + return "" +} + // IncRef implements DentryImpl.IncRef. func (d *anonDentry) IncRef() { // no-op diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 320ab7ce1..e7ca24d96 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -211,12 +211,14 @@ func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dent // AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion // fails. +// +checklocks:d.mu func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) { d.mu.Unlock() } // CommitDeleteDentry must be called after PrepareDeleteDentry if the deletion // succeeds. +// +checklocks:d.mu func (vfs *VirtualFilesystem) CommitDeleteDentry(ctx context.Context, d *Dentry) { d.dead = true d.mu.Unlock() @@ -270,6 +272,8 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t // AbortRenameDentry must be called after PrepareRenameDentry if the rename // fails. +// +checklocks:from.mu +// +checklocks:to.mu func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) { from.mu.Unlock() if to != nil { @@ -282,6 +286,8 @@ func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) { // that was replaced by from. // // Preconditions: PrepareRenameDentry was previously called on from and to. +// +checklocks:from.mu +// +checklocks:to.mu func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, from, to *Dentry) { from.mu.Unlock() if to != nil { @@ -297,6 +303,8 @@ func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, fro // from and to are exchanged by rename(RENAME_EXCHANGE). // // Preconditions: PrepareRenameDentry was previously called on from and to. +// +checklocks:from.mu +// +checklocks:to.mu func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) { from.mu.Unlock() to.mu.Unlock() diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 2c4b81e78..059939010 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -502,6 +502,15 @@ type FilesystemImpl interface { // // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error + + // MountOptions returns mount options for the current filesystem. This + // should only return options specific to the filesystem (i.e. don't return + // "ro", "rw", etc). Options should be returned as a comma-separated string, + // similar to the input to the 5th argument to mount. + // + // If the implementation has no filesystem-specific options, it should + // return the empty string. + MountOptions() string } // PrependPathAtVFSRootError is returned by implementations of diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 7063066ff..922f9e697 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -217,20 +217,21 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr return err } vfs.mountMu.Lock() - vd.dentry.mu.Lock() + vdDentry := vd.dentry + vdDentry.mu.Lock() for { - if vd.dentry.dead { - vd.dentry.mu.Unlock() + if vdDentry.dead { + vdDentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef(ctx) return syserror.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and // vfs.mountMu.Lock(). - if !vd.dentry.isMounted() { + if !vdDentry.isMounted() { break } - nextmnt := vfs.mounts.Lookup(vd.mount, vd.dentry) + nextmnt := vfs.mounts.Lookup(vd.mount, vdDentry) if nextmnt == nil { break } @@ -243,13 +244,13 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr } // This can't fail since we're holding vfs.mountMu. nextmnt.root.IncRef() - vd.dentry.mu.Unlock() + vdDentry.mu.Unlock() vd.DecRef(ctx) vd = VirtualDentry{ mount: nextmnt, dentry: nextmnt.root, } - vd.dentry.mu.Lock() + vdDentry.mu.Lock() } // TODO(gvisor.dev/issue/1035): Linux requires that either both the mount // point and the mount root are directories, or neither are, and returns @@ -258,7 +259,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr vfs.mounts.seq.BeginWrite() vfs.connectLocked(mnt, vd, mntns) vfs.mounts.seq.EndWrite() - vd.dentry.mu.Unlock() + vdDentry.mu.Unlock() vfs.mountMu.Unlock() return nil } @@ -958,13 +959,17 @@ func manglePath(p string) string { // superBlockOpts returns the super block options string for the the mount at // the given path. func superBlockOpts(mountPath string, mnt *Mount) string { - // gVisor doesn't (yet) have a concept of super block options, so we - // use the ro/rw bit from the mount flag. + // Compose super block options by combining global mount flags with + // FS-specific mount options. opts := "rw" if mnt.ReadOnly() { opts = "ro" } + if mopts := mnt.fs.Impl().MountOptions(); mopts != "" { + opts += "," + mopts + } + // NOTE(b/147673608): If the mount is a cgroup, we also need to include // the cgroup name in the options. For now we just read that from the // path. diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go index 4e49a9b89..411a80a8a 100644 --- a/pkg/sync/mutex_unsafe.go +++ b/pkg/sync/mutex_unsafe.go @@ -72,6 +72,7 @@ func (m *Mutex) Lock() { // Preconditions: // * m is locked. // * m was locked by this goroutine. +// +checklocksignore func (m *Mutex) Unlock() { noteUnlock(unsafe.Pointer(m)) m.m.Unlock() diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go index 4cf3fcd6e..892d3e641 100644 --- a/pkg/sync/rwmutex_unsafe.go +++ b/pkg/sync/rwmutex_unsafe.go @@ -105,6 +105,7 @@ func (rw *CrossGoroutineRWMutex) RUnlock() { // TryLock locks rw for writing. It returns true if it succeeds and false // otherwise. It does not block. +// +checklocksignore func (rw *CrossGoroutineRWMutex) TryLock() bool { if RaceEnabled { RaceDisable() @@ -155,6 +156,7 @@ func (rw *CrossGoroutineRWMutex) Lock() { // // Preconditions: // * rw is locked for writing. +// +checklocksignore func (rw *CrossGoroutineRWMutex) Unlock() { if RaceEnabled { RaceRelease(unsafe.Pointer(&rw.writerSem)) @@ -181,6 +183,7 @@ func (rw *CrossGoroutineRWMutex) Unlock() { // // Preconditions: // * rw is locked for writing. +// +checklocksignore func (rw *CrossGoroutineRWMutex) DowngradeLock() { if RaceEnabled { RaceRelease(unsafe.Pointer(&rw.readerSem)) @@ -250,6 +253,7 @@ func (rw *RWMutex) RLock() { // Preconditions: // * rw is locked for reading. // * rw was locked by this goroutine. +// +checklocksignore func (rw *RWMutex) RUnlock() { rw.m.RUnlock() noteUnlock(unsafe.Pointer(rw)) @@ -279,6 +283,7 @@ func (rw *RWMutex) Lock() { // Preconditions: // * rw is locked for writing. // * rw was locked by this goroutine. +// +checklocksignore func (rw *RWMutex) Unlock() { rw.m.Unlock() noteUnlock(unsafe.Pointer(rw)) @@ -288,6 +293,7 @@ func (rw *RWMutex) Unlock() { // // Preconditions: // * rw is locked for writing. +// +checklocksignore func (rw *RWMutex) DowngradeLock() { // No note change for DowngradeLock. rw.m.DowngradeLock() diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 0b9139570..79e564de6 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -51,6 +51,7 @@ var ( ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM) ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT) ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), linux.EINVAL) + ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), linux.EINVAL) ) // TranslateNetstackError converts an error from the tcpip package to a sentry @@ -135,6 +136,8 @@ func TranslateNetstackError(err tcpip.Error) *Error { return ErrBadBuffer case *tcpip.ErrMalformedHeader: return ErrMalformedHeader + case *tcpip.ErrInvalidPortRange: + return ErrInvalidPortRange default: panic(fmt.Sprintf("unknown error %T", err)) } diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index b05e81526..f4a30effd 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -196,7 +196,7 @@ func (vv *VectorisedView) CapLength(length int) { // If the buffer argument is large enough to contain all the Views of this // VectorisedView, the method will avoid allocations and use the buffer to // store the Views of the clone. -func (vv *VectorisedView) Clone(buffer []View) VectorisedView { +func (vv VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } @@ -290,6 +290,14 @@ func (vv *VectorisedView) AppendView(v View) { vv.size += len(v) } +// AppendViews appends views to vv. +func (vv *VectorisedView) AppendViews(views []View) { + vv.views = append(vv.views, views...) + for _, v := range views { + vv.size += len(v) + } +} + // Readers returns a bytes.Reader for each of vv's views. func (vv *VectorisedView) Readers() []bytes.Reader { readers := make([]bytes.Reader, 0, len(vv.views)) diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 78b2faa26..d296d9c2b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -45,6 +45,11 @@ func vv(size int, pieces ...string) buffer.VectorisedView { return buffer.NewVectorisedView(size, views) } +// v returns a buffer.View containing piece. +func v(piece string) buffer.View { + return buffer.View(piece) +} + var capLengthTestCases = []struct { comment string in buffer.VectorisedView @@ -125,6 +130,12 @@ var trimFrontTestCases = []struct { want: vv(1, "3"), }, { + comment: "Case with one empty Views", + in: vv(3, "1", "", "23"), + count: 2, + want: vv(1, "3"), + }, + { comment: "Corner case with negative count", in: vv(1, "1"), count: -1, @@ -566,11 +577,11 @@ func TestAppendView(t *testing.T) { in buffer.View want buffer.VectorisedView }{ - {buffer.VectorisedView{}, nil, buffer.VectorisedView{}}, - {buffer.VectorisedView{}, buffer.View{}, buffer.VectorisedView{}}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), nil, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{}, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{'e'}, buffer.NewVectorisedView(5, []buffer.View{{'a', 'b', 'c', 'd'}, {'e'}})}, + {vv(0), nil, vv(0)}, + {vv(0), v(""), vv(0)}, + {vv(4, "abcd"), nil, vv(4, "abcd")}, + {vv(4, "abcd"), v(""), vv(4, "abcd")}, + {vv(4, "abcd"), v("e"), vv(5, "abcd", "e")}, } for _, tc := range testCases { tc.vv.AppendView(tc.in) @@ -580,6 +591,31 @@ func TestAppendView(t *testing.T) { } } +func TestAppendViews(t *testing.T) { + testCases := []struct { + vv buffer.VectorisedView + in []buffer.View + want buffer.VectorisedView + }{ + {vv(0), nil, vv(0)}, + {vv(0), []buffer.View{}, vv(0)}, + {vv(0), []buffer.View{v("")}, vv(0, "")}, + {vv(4, "abcd"), nil, vv(4, "abcd")}, + {vv(4, "abcd"), []buffer.View{}, vv(4, "abcd")}, + {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")}, + {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")}, + {vv(4, "abcd"), []buffer.View{v("e")}, vv(5, "abcd", "e")}, + {vv(4, "abcd"), []buffer.View{v("e"), v("fg")}, vv(7, "abcd", "e", "fg")}, + {vv(4, "abcd"), []buffer.View{v(""), v("fg")}, vv(6, "abcd", "", "fg")}, + } + for _, tc := range testCases { + tc.vv.AppendViews(tc.in) + if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) + } + } +} + func TestMemSize(t *testing.T) { const perViewCap = 128 views := make([]buffer.View, 2, 32) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 07b4393a4..fc622b246 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -567,7 +567,7 @@ func TCPWindowLessThanEq(window uint16) TransportChecker { } // TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags uint8) TransportChecker { +func TCPFlags(flags header.TCPFlags) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() @@ -576,15 +576,15 @@ func TCPFlags(flags uint8) TransportChecker { t.Fatalf("TCP header not found in h: %T", h) } - if f := tcp.Flags(); f != flags { - t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) + if got := tcp.Flags(); got != flags { + t.Errorf("got tcp.Flags() = %s, want %s", got, flags) } } } // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the // given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask uint8) TransportChecker { +func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() @@ -593,8 +593,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { t.Fatalf("TCP header not found in h: %T", h) } - if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) + if got := tcp.Flags(); (got & mask) != (flags & mask) { + t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask) } } } @@ -985,7 +985,11 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { } icmp := header.ICMPv6(last.Payload()) - if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { + if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: last.SourceAddress(), + Dst: last.DestinationAddress(), + }); got != want { t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) } diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go index 3b7cc52f3..5d478ac32 100644 --- a/pkg/tcpip/errors.go +++ b/pkg/tcpip/errors.go @@ -300,6 +300,19 @@ func (*ErrInvalidOptionValue) IgnoreStats() bool { } func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" } +// ErrInvalidPortRange indicates an attempt to set an invalid port range. +// +// +stateify savable +type ErrInvalidPortRange struct{} + +func (*ErrInvalidPortRange) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidPortRange) IgnoreStats() bool { + return true +} +func (*ErrInvalidPortRange) String() string { return "invalid port range" } + // ErrMalformedHeader indicates the operation encountered a malformed header. // // +stateify savable diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 14a4b2b44..6aa9acfa8 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -186,42 +186,29 @@ func Checksum(buf []byte, initial uint16) uint16 { // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) + var c Checksumer + for _, v := range vv.Views() { + c.Add([]byte(v)) + } + return ChecksumCombine(initial, c.Checksum()) } -// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the -// bytes in the given VectorizedView. -// -// The initial checksum must have been computed on an even number of bytes. -func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { - odd := false - sum := initial - for _, v := range vv.Views() { - if len(v) == 0 { - continue - } - - if off >= len(v) { - off -= len(v) - continue - } - v = v[off:] - - l := len(v) - if l > size { - l = size - } - v = v[:l] - - sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum)) - - size -= len(v) - if size == 0 { - break - } - off = 0 +// Checksumer calculates checksum defined in RFC 1071. +type Checksumer struct { + sum uint16 + odd bool +} + +// Add adds b to checksum. +func (c *Checksumer) Add(b []byte) { + if len(b) > 0 { + c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum)) } - return sum +} + +// Checksum returns the latest checksum value. +func (c *Checksumer) Checksum() uint16 { + return c.sum } // ChecksumCombine combines the two uint16 to form their checksum. This is done diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 5ab20ee86..d267dabd0 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -17,6 +17,7 @@ package header_test import ( + "bytes" "fmt" "math/rand" "sync" @@ -26,86 +27,72 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -func TestChecksumVVWithOffset(t *testing.T) { +func TestChecksumer(t *testing.T) { testCases := []struct { - name string - vv buffer.VectorisedView - off, size int - initial uint16 - want uint16 + name string + data [][]byte + want uint16 }{ { name: "empty", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 0, want: 0, }, { - name: "OneView", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 5, + name: "OneOddView", + data: [][]byte{ + []byte{1, 9, 0, 5, 4}, + }, want: 1294, }, { - name: "TwoViews", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 0, - size: 11, + name: "TwoOddViews", + data: [][]byte{ + []byte{1, 9, 0, 5, 4}, + []byte{4, 3, 7, 1, 2, 123}, + }, want: 33819, }, { - name: "TwoViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 1, - size: 11, - want: 33819, + name: "OneEvenView", + data: [][]byte{ + []byte{1, 9, 0, 5}, + }, + want: 270, }, { - name: "ThreeViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 7, - size: 11, - want: 33819, + name: "TwoEvenViews", + data: [][]byte{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0}), + buffer.NewViewFromBytes([]byte{9, 0, 5, 4}), + }, + want: 30981, }, { - name: "ThreeViewsWithInitial", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), - }), - initial: 77, - off: 7, - size: 11, - want: 33896, + name: "ThreeViews", + data: [][]byte{ + []byte{77, 11, 33, 0, 55, 44}, + []byte{98, 1, 9, 0, 5, 4}, + []byte{4, 3, 7, 1, 2, 123, 99}, + }, + want: 34236, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { - t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) + var all bytes.Buffer + var c header.Checksumer + for _, b := range tc.data { + c.Add(b) + // Append to the buffer. We will check the checksum as a whole later. + if _, err := all.Write(b); err != nil { + t.Fatalf("all.Write(b) = _, %s; want _, nil", err) + } + } + if got, want := c.Checksum(), tc.want; got != want { + t.Errorf("c.Checksum() = %d, want %d", got, want) } - v := tc.vv.ToView() - v.TrimFront(tc.off) - v.CapLength(tc.size) - if got, want := header.Checksum(v, tc.initial), tc.want; got != want { - t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) + if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want { + t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want) } }) } @@ -228,7 +215,7 @@ func TestICMPv4Checksum(t *testing.T) { h.SetChecksum(want) testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv4Checksum(h, vv) + return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0)) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } @@ -260,6 +247,12 @@ func TestICMPv6Checksum(t *testing.T) { h.SetChecksum(want) testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv6Checksum(h, src, dst, vv) + return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: h, + Src: src, + Dst: dst, + PayloadCsum: header.ChecksumVV(vv, 0), + PayloadLen: vv.Size(), + }) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index f840a4322..91c1c3cd2 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -18,7 +18,6 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. @@ -198,8 +197,8 @@ func (b ICMPv4) SetSequence(sequence uint16) { // ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, // and payload. -func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { - xsum := ChecksumVV(vv, 0) +func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 { + xsum := payloadCsum // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. xsum = Checksum(h[:2], xsum) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index eca9750ab..668da623a 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -18,7 +18,6 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. @@ -262,12 +261,22 @@ func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } +// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum. +type ICMPv6ChecksumParams struct { + Header ICMPv6 + Src tcpip.Address + Dst tcpip.Address + PayloadCsum uint16 + PayloadLen int +} + // ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. -func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) +func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 { + h := params.Header - xsum = ChecksumVV(vv, xsum) + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen)) + xsum = ChecksumCombine(xsum, params.PayloadCsum) // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. xsum = Checksum(h[:2], xsum) diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 2042f214a..ebb4b2c1d 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -41,7 +41,7 @@ func ARP(pkt *stack.PacketBuffer) bool { // // Returns true if the header was successfully parsed. func IPv4(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + hdr, ok := pkt.Data().PullUp(header.IPv4MinimumSize) if !ok { return false } @@ -62,27 +62,29 @@ func IPv4(pkt *stack.PacketBuffer) bool { ipHdr = header.IPv4(hdr) pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr)) return true } // IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network // header with the IPv6 header. func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + hdr, ok := pkt.Data().PullUp(header.IPv6MinimumSize) if !ok { return 0, 0, 0, false, false } ipHdr := header.IPv6(hdr) - // dataClone consists of: + // Create a VV to parse the packet. We don't plan to modify anything here. + // dataVV consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. views := [8]buffer.View{} - dataClone := pkt.Data.Clone(views[:]) - dataClone.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + dataVV := buffer.NewVectorisedView(0, views[:0]) + dataVV.AppendViews(pkt.Data().Views()) + dataVV.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataVV) // Iterate over the IPv6 extensions to find their length. var nextHdr tcpip.TransportProtocolNumber @@ -98,7 +100,7 @@ traverseExtensions: // If we exhaust the extension list, the entire packet is the IPv6 header // and (possibly) extensions. if done { - extensionsSize = dataClone.Size() + extensionsSize = dataVV.Size() break } @@ -110,12 +112,12 @@ traverseExtensions: fragMore = extHdr.More() } rawPayload := it.AsRawHeader(true /* consume */) - extensionsSize = dataClone.Size() - rawPayload.Buf.Size() + extensionsSize = dataVV.Size() - rawPayload.Buf.Size() break traverseExtensions case header.IPv6RawPayloadHeader: // We've found the payload after any extensions. - extensionsSize = dataClone.Size() - extHdr.Buf.Size() + extensionsSize = dataVV.Size() - extHdr.Buf.Size() nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) break traverseExtensions @@ -127,10 +129,10 @@ traverseExtensions: // Put the IPv6 header with extensions in pkt.NetworkHeader(). hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) if !ok { - panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data().Size())) } ipHdr = header.IPv6(hdr) - pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.Data().CapLength(int(ipHdr.PayloadLength())) pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber return nextHdr, fragID, fragOffset, fragMore, true @@ -153,13 +155,13 @@ func UDP(pkt *stack.PacketBuffer) bool { func TCP(pkt *stack.PacketBuffer) bool { // TCP header is variable length, peek at it first. hdrLen := header.TCPMinimumSize - hdr, ok := pkt.Data.PullUp(hdrLen) + hdr, ok := pkt.Data().PullUp(hdrLen) if !ok { return false } // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data().Size() { // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of // packets. hdrLen = offset diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 4c6f808e5..adc835d30 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -45,9 +45,23 @@ const ( TCPMaxSACKBlocks = 4 ) +// TCPFlags is the dedicated type for TCP flags. +type TCPFlags uint8 + +// String implements Stringer.String. +func (f TCPFlags) String() string { + flagsStr := []byte("FSRPAU") + for i := range flagsStr { + if f&(1<<uint(i)) == 0 { + flagsStr[i] = ' ' + } + } + return string(flagsStr) +} + // Flags that may be set in a TCP segment. const ( - TCPFlagFin = 1 << iota + TCPFlagFin TCPFlags = 1 << iota TCPFlagSyn TCPFlagRst TCPFlagPsh @@ -94,7 +108,7 @@ type TCPFields struct { DataOffset uint8 // Flags is the "flags" field of a TCP packet. - Flags uint8 + Flags TCPFlags // WindowSize is the "window size" field of a TCP packet. WindowSize uint16 @@ -234,8 +248,8 @@ func (b TCP) Payload() []byte { } // Flags returns the flags field of the tcp header. -func (b TCP) Flags() uint8 { - return b[TCPFlagsOffset] +func (b TCP) Flags() TCPFlags { + return TCPFlags(b[TCPFlagsOffset]) } // WindowSize returns the "window size" field of the tcp header. @@ -319,10 +333,10 @@ func (b TCP) ParsedOptions() TCPOptions { return ParseTCPOptions(b.Options()) } -func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) { +func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq) binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack) - b[TCPFlagsOffset] = flags + b[TCPFlagsOffset] = uint8(flags) binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } @@ -338,7 +352,7 @@ func (b TCP) Encode(t *TCPFields) { // EncodePartial updates a subset of the fields of the tcp header. It is useful // in cases when similar segments are produced. -func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) { +func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { // Add the total length and "flags" field contributions to the checksum. // We don't use the flags field directly from the header because it's a // one-byte field with an odd offset, so it would be accounted for diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go index 72563837b..96db8460f 100644 --- a/pkg/tcpip/header/tcp_test.go +++ b/pkg/tcpip/header/tcp_test.go @@ -146,3 +146,23 @@ func TestTCPParseOptions(t *testing.T) { } } } + +func TestTCPFlags(t *testing.T) { + for _, tt := range []struct { + flags header.TCPFlags + want string + }{ + {header.TCPFlagFin, "F "}, + {header.TCPFlagSyn, " S "}, + {header.TCPFlagRst, " R "}, + {header.TCPFlagPsh, " P "}, + {header.TCPFlagAck, " A "}, + {header.TCPFlagUrg, " U"}, + {header.TCPFlagSyn | header.TCPFlagAck, " S A "}, + {header.TCPFlagFin | header.TCPFlagAck, "F A "}, + } { + if got := tt.flags.String(); got != tt.want { + t.Errorf("got TCPFlags(%#b).String() = %s, want = %s", tt.flags, got, tt.want) + } + } +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 72d3f70ac..e17e2085c 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -427,7 +427,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen vnetHdr.csumOffset = gso.CsumOffset } - if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { + if gso.Type != stack.GSONone && uint16(pkt.Data().Size()) > gso.MSS { switch gso.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 @@ -468,7 +468,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset } - if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data.Size()) > pkt.GSOOptions.MSS { + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { switch pkt.GSOOptions.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 358a030d2..1e40f3fef 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -67,7 +67,7 @@ func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { LinkHeader: pk.LinkHeader().View(), NetworkHeader: pk.NetworkHeader().View(), TransportHeader: pk.TransportHeader().View(), - Data: pk.Data.ToView(), + Data: pk.Data().AsRange().ToOwnedView(), } }), ); diff != "" { @@ -616,8 +616,8 @@ func TestDispatchPacketFormat(t *testing.T) { if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) } - if got, want := pkt.Data.Size(), 4; got != want { - t.Errorf("pkt.Data.Size() = %d, want %d", got, want) + if got, want := pkt.Data().Size(), 4; got != want { + t.Errorf("pkt.Data().Size() = %d, want %d", got, want) } }) } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 736871d1c..46df87f44 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -165,7 +165,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. // IP version information is at the first octet, so pulling up 1 byte. - h, ok := pkt.Data.PullUp(1) + h, ok := pkt.Data().PullUp(1) if !ok { return true, nil } @@ -270,7 +270,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. // IP version information is at the first octet, so pulling up 1 byte. - h, ok := pkt.Data.PullUp(1) + h, ok := pkt.Data().PullUp(1) if !ok { // Skip this packet. continue diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index def47772f..d4b3ddd5c 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -80,7 +80,7 @@ func (q *queueBuffers) cleanup() { type packetInfo struct { addr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView + data buffer.View linkHeader buffer.View } @@ -136,7 +136,7 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, proto: proto, - vv: pkt.Data.Clone(nil), + data: pkt.Data().AsRange().ToOwnedView(), }) c.mu.Unlock() @@ -676,7 +676,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].data) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index bd2b8d4bf..7aaee3d13 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -290,7 +290,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + hdr, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { break } @@ -327,7 +327,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + hdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize) if !ok { break } @@ -387,7 +387,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments { + if size := pkt.Data().Size() + len(tcp); offset > size && !moreFragments { details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size) break } @@ -398,13 +398,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe // Initialize the TCP flags. flags := tcp.Flags() - flagsStr := []byte("FSRPAU") - for i := range flagsStr { - if flags&(1<<uint(i)) == 0 { - flagsStr[i] = ' ' - } - } - details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) + details = fmt.Sprintf("flags: %s seqnum: %d ack: %d win: %d xsum:0x%x", flags, tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) if flags&header.TCPFlagSyn != 0 { details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0)) } else { diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 3829ca9c9..c1678c4f4 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -281,7 +281,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { vv.AppendView(info.Pkt.NetworkHeader().View()) vv.AppendView(info.Pkt.TransportHeader().View()) // Append data payload. - vv.Append(info.Pkt.Data) + vv.Append(info.Pkt.Data().ExtractVV()) return vv.ToView(), true } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 3fcdea119..ae0461a6d 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -232,7 +232,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.mu.Lock() - e.mu.dad.StopLocked(addr, false /* aborted */) + e.mu.dad.StopLocked(addr, &stack.DADDupAddrDetected{HolderLinkAddress: linkAddr}) e.mu.Unlock() // The solicited, override, and isRouter flags are not available for ARP; diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go index 243738951..5168f5361 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go @@ -170,7 +170,7 @@ func (f *Fragmentation) Process( return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } - if l := pkt.Data.Size(); l != int(fragmentSize) { + if l := pkt.Data().Size(); l != int(fragmentSize) { return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } @@ -293,7 +293,7 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, re // these headers. var fragmentableData buffer.VectorisedView fragmentableData.AppendView(pkt.TransportHeader().View()) - fragmentableData.Append(pkt.Data) + fragmentableData.Append(pkt.Data().ExtractVV()) fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen return PacketFragmenter{ @@ -323,7 +323,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, }) // Copy data for the fragment. - copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) + copied := fragPkt.Data().ReadFromVV(&pf.data, pf.fragmentPayloadLen) offset := pf.fragmentOffset pf.fragmentOffset += copied diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 47ea3173e..7daf64b4a 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -121,7 +121,7 @@ func TestFragmentationProcess(t *testing.T) { in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) } if c.out[i].done { - if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" { + if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" { t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) } @@ -470,9 +470,7 @@ func TestPacketFragmenter(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) - var originalPayload buffer.VectorisedView - originalPayload.AppendView(pkt.TransportHeader().View()) - originalPayload.Append(pkt.Data) + originalPayload := stack.PayloadSince(pkt.TransportHeader()) var reassembledPayload buffer.VectorisedView pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) for i := 0; ; i++ { @@ -499,7 +497,7 @@ func TestPacketFragmenter(t *testing.T) { if got := fragPkt.TransportHeader().View().Size(); got != 0 { t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) } - reassembledPayload.Append(fragPkt.Data) + reassembledPayload.AppendViews(fragPkt.Data().Views()) if !more { if i != len(test.wantFragments)-1 { t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) @@ -507,7 +505,7 @@ func TestPacketFragmenter(t *testing.T) { break } } - if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { + if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" { t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } }) @@ -625,11 +623,11 @@ func TestTimeoutHandler(t *testing.T) { } switch { case handler.pkt != nil && test.wantPkt == nil: - t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) + t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView()) case handler.pkt == nil && test.wantPkt != nil: - t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) + t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView()) case handler.pkt != nil && test.wantPkt != nil: - if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { + if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" { t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) } } diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 933d63d32..90075a70c 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -167,8 +167,8 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s resPkt := r.holes[0].pkt for i := 1; i < len(r.holes); i++ { - fragPkt := r.holes[i].pkt - fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size()) + fragData := r.holes[i].pkt.Data() + resPkt.Data().ReadFromData(fragData, fragData.Size()) } return resPkt, r.proto, true, memConsumed, nil } diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go index 214a93709..cfd9f00ef 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go @@ -204,7 +204,7 @@ func TestReassemblerProcess(t *testing.T) { if a == nil || b == nil { return a == b } - return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView()) + return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView()) } if isDone { diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index 6f89a6a16..0053646ee 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -126,9 +126,12 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet s.timer.Stop() delete(d.addresses, addr) - r := stack.DADResult{Resolved: dadDone, Err: err} + var res stack.DADResult = &stack.DADSucceeded{} + if err != nil { + res = &stack.DADError{Err: err} + } for _, h := range s.completionHandlers { - h(r) + h(res) } }), } @@ -142,7 +145,7 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet // StopLocked stops a currently running DAD process. // // Precondition: d.protocolMU must be locked. -func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) { +func (d *DAD) StopLocked(addr tcpip.Address, reason stack.DADResult) { s, ok := d.addresses[addr] if !ok { return @@ -152,14 +155,8 @@ func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) { s.timer.Stop() delete(d.addresses, addr) - var err tcpip.Error - if aborted { - err = &tcpip.ErrAborted{} - } - - r := stack.DADResult{Resolved: false, Err: err} for _, h := range s.completionHandlers { - h(r) + h(reason) } } diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index 18c357b56..e00aa4678 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -78,10 +78,10 @@ func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADC return m.mu.dad.CheckDuplicateAddressLocked(addr, h) } -func (m *mockDADProtocol) stop(addr tcpip.Address, aborted bool) { +func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) { m.mu.Lock() defer m.mu.Unlock() - m.mu.dad.StopLocked(addr, aborted) + m.mu.dad.StopLocked(addr, reason) } func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { @@ -175,7 +175,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { } clock.Advance(delta) for i := 0; i < 2; i++ { - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff) } } @@ -189,7 +189,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { default: } clock.Advance(delta) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } @@ -202,7 +202,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } clock.Advance(dadConfig2Duration) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } @@ -241,19 +241,19 @@ func TestDADStop(t *testing.T) { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } - dad.stop(addr1, true /* aborted */) - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: false, Err: &tcpip.ErrAborted{}}}, <-ch); diff != "" { + dad.stop(addr1, &stack.DADAborted{}) + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } - dad.stop(addr2, false /* aborted */) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: false, Err: nil}}, <-ch); diff != "" { + dad.stop(addr2, &stack.DADDupAddrDetected{}) + if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr3, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } @@ -266,7 +266,7 @@ func TestDADStop(t *testing.T) { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 5f7e60c5c..b6f39ddb1 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -69,8 +69,8 @@ type MultiCounterIPStats struct { IPTablesOutputDropped tcpip.MultiCounterStat // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out - // of IPStats. + // OptionTimestampReceived is the number of Timestamp options seen. OptionTimestampReceived tcpip.MultiCounterStat diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 90236ed9e..aee1652fa 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -90,8 +90,7 @@ type testObject struct { // checkValues verifies that the transport protocol, data contents, src & dst // addresses of a packet match what's expected. If any field doesn't match, the // test fails. -func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { - v := vv.ToView() +func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) { if protocol != t.protocol { t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) } @@ -120,7 +119,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // parsing are expected. func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { netHdr := pkt.Network() - t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress()) + t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress()) t.dataCalls++ return stack.TransportPacketHandled } @@ -129,7 +128,7 @@ func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumb // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { - t.checkValues(trans, pkt.Data, remote, local) + t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local) if diff := cmp.Diff( t.transErr, transportError{ @@ -198,7 +197,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, pkt.Data, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr) return nil } @@ -371,7 +370,11 @@ func TestSourceAddressValidation(t *testing.T) { pkt.SetType(header.ICMPv6EchoRequest) pkt.SetCode(0) pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: src, + Dst: localIPv6Addr, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, @@ -1199,7 +1202,11 @@ func TestIPv6ReceiveControl(t *testing.T) { nic.testObject.transErr = c.transErr // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: outerSrcAddr, + Dst: localIPv6Addr, + })) addressableEndpoint, ok := ep.(stack.AddressableEndpoint) if !ok { diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 4b21ee79c..5e7f10f4b 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -32,12 +32,14 @@ go_test( "ipv4_test.go", ], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/internal/testutil", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index bd0eabad1..deb104837 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -137,7 +137,7 @@ func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool { // is used to find out which transport endpoint must be notified about the ICMP // packet. We only expect the payload, not the enclosing ICMP packet. func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + h, ok := pkt.Data().PullUp(header.IPv4MinimumSize) if !ok { return } @@ -156,7 +156,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet } hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + if pkt.Data().Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -164,7 +164,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet } // Skip the ip header, then deliver the error. - pkt.Data.TrimFront(hlen) + pkt.Data().TrimFront(hlen) p := hdr.TransportProtocol() e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) } @@ -174,7 +174,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { received.invalid.Increment() return @@ -182,7 +182,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. - if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { + if pkt.Data().AsRange().Checksum() != 0xffff { received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because @@ -238,7 +238,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: - received.echo.Increment() + received.echoRequest.Increment() sent := e.stats.icmp.packetsSent if !e.protocol.stack.AllowICMPMessage() { @@ -253,7 +253,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. - replyData := pkt.Data.ToOwnedView() + replyData := pkt.Data().AsRange().ToOwnedView() ipHdr := header.IPv4(pkt.NetworkHeader().View()) localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast @@ -336,7 +336,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4DstUnreachable: received.dstUnreachable.Increment() - pkt.Data.TrimFront(header.ICMPv4MinimumSize) + pkt.Data().TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) @@ -571,7 +571,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return nil } - payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size() + payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data().Size() if payloadLen > available { payloadLen = available } @@ -586,8 +586,11 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip newHeader := append(buffer.View(nil), origIPHdr...) newHeader = append(newHeader, transportHeader...) payload := newHeader.ToVectorisedView() - payload.AppendView(pkt.Data.ToView()) - payload.CapLength(payloadLen) + if dataCap := payloadLen - payload.Size(); dataCap > 0 { + payload.AppendView(pkt.Data().AsRange().Capped(dataCap).ToOwnedView()) + } else { + payload.CapLength(payloadLen) + } icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize, @@ -623,7 +626,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( nil, /* gso */ diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 0a15ae897..f3fc1c87e 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -197,7 +197,7 @@ func (igmp *igmpState) isPacketValidLocked(pkt *stack.PacketBuffer, messageType // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer, hasRouterAlertOption bool) { received := igmp.ep.stats.igmp.packetsReceived - headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) + headerView, ok := pkt.Data().PullUp(header.IGMPMinimumSize) if !ok { received.invalid.Increment() return @@ -210,7 +210,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer, hasRouterAlertOption // same set of octets, including the checksum field. If the result // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. - if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xFFFF { + if pkt.Data().AsRange().Checksum() != 0xFFFF { received.checksumErrors.Increment() return } diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index c5f68e411..e5e1b89cc 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -106,9 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma igmp.SetGroupAddress(groupAddress) igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } // TestIGMPV1Present tests the node's ability to fallback to V1 when a V1 diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 4a429ea6c..8a2140ebe 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -492,7 +492,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + h, ok := pkt.Data().PullUp(header.IPv4MinimumSize) if !ok { return &tcpip.ErrMalformedHeader{} } @@ -502,14 +502,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return &tcpip.ErrMalformedHeader{} } - h, ok = pkt.Data.PullUp(int(hdrLen)) + h, ok = pkt.Data().PullUp(int(hdrLen)) if !ok { return &tcpip.ErrMalformedHeader{} } ip := header.IPv4(h) // Always set the total length. - pktSize := pkt.Data.Size() + pktSize := pkt.Data().Size() ip.SetTotalLength(uint16(pktSize)) // Set the source address when zero. @@ -687,7 +687,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { stats := e.stats h := header.IPv4(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { + if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { stats.ip.MalformedPacketsReceived.Increment() return } @@ -765,7 +765,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } if h.More() || h.FragmentOffset() != 0 { - if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { + if pkt.Data().Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. stats.ip.MalformedPacketsReceived.Increment() @@ -793,10 +793,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // maximum payload size. // // Note that this addition doesn't overflow even on 32bit architecture - // because pkt.Data.Size() should not exceed 65535 (the max IP datagram + // because pkt.Data().Size() should not exceed 65535 (the max IP datagram // size). Otherwise the packet would've been rejected as invalid before // reaching here. - if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { + if int(start)+pkt.Data().Size() > header.IPv4MaximumPayloadSize { stats.ip.MalformedPacketsReceived.Increment() stats.ip.MalformedFragmentsReceived.Increment() return @@ -813,7 +813,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { Protocol: proto, }, start, - start+uint16(pkt.Data.Size())-1, + start+uint16(pkt.Data().Size())-1, h.More(), proto, pkt, @@ -831,7 +831,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // The reassembler doesn't take care of fixing up the header, so we need // to do it here. - h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) + h.SetTotalLength(uint16(pkt.Data().Size() + len((h)))) h.SetFlagsFragmentOffset(0, 0) } stats.ip.PacketsDelivered.Increment() @@ -899,10 +899,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() { e.mu.Lock() - defer e.mu.Unlock() - e.disableLocked() e.mu.addressableEndpointState.Cleanup() + e.mu.Unlock() e.protocol.forgetEndpoint(e.nic.ID()) } @@ -1186,7 +1185,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error } func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { - payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + payload := pkt.TransportHeader().View().Size() + pkt.Data().Size() return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index dc4db6e5f..cfed241bf 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -26,12 +26,14 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" @@ -1211,7 +1213,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) } reassembledPayload.AppendView(packet.TransportHeader().View()) - reassembledPayload.Append(packet.Data) + reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView()) // Clear out the checksum and length from the ip because we can't compare // it. sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) @@ -2985,3 +2987,120 @@ func TestPacketQueing(t *testing.T) { }) } } + +// TestCloseLocking test that lock ordering is followed when closing an +// endpoint. +func TestCloseLocking(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + src = tcpip.Address("\x10\x00\x00\x01") + dst = tcpip.Address("\x10\x00\x00\x02") + + iterations = 1000 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + // Perform NAT so that the endoint tries to search for a sibling endpoint + // which ends up taking the protocol and endpoint lock (in that order). + table := stack.Table{ + Rules: []stack.Rule{ + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 1, + stack.Forward: stack.HookUnset, + stack.Output: 2, + stack.Postrouting: 3, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 1, + stack.Forward: stack.HookUnset, + stack.Output: 2, + stack.Postrouting: 3, + }, + } + if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil { + t.Fatalf("s.IPTables().ReplaceTable(...): %s", err) + } + + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }}) + + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + defer ep.Close() + + addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53} + if err := ep.Connect(addr); err != nil { + t.Errorf("ep.Connect(%#v): %s", addr, err) + } + + var wg sync.WaitGroup + defer wg.Wait() + + // Writing packets should trigger NAT which requires the stack to search the + // protocol for network endpoints with the destination address. + // + // Creating and removing interfaces should modify the protocol and endpoint + // which requires taking the locks of each. + // + // We expect the protocol > endpoint lock ordering to be followed here. + wg.Add(2) + go func() { + defer wg.Done() + + data := []byte{1, 2, 3, 4} + + for i := 0; i < iterations; i++ { + var r bytes.Reader + r.Reset(data) + if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Errorf("ep.Write(_, _): %s", err) + return + } else if want := int64(len(data)); n != want { + t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) + return + } + } + }() + go func() { + defer wg.Done() + + for i := 0; i < iterations; i++ { + if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + t.Errorf("CreateNIC(%d, _): %s", nicID2, err) + return + } + if err := s.RemoveNIC(nicID2); err != nil { + t.Errorf("RemoveNIC(%d): %s", nicID2, err) + return + } + } + }() +} diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go index 5ae73fbfb..5798cfec6 100644 --- a/pkg/tcpip/network/ipv4/stats.go +++ b/pkg/tcpip/network/ipv4/stats.go @@ -52,7 +52,7 @@ type sharedStats struct { // LINT.IfChange(multiCounterICMPv4PacketStats) type multiCounterICMPv4PacketStats struct { - echo tcpip.MultiCounterStat + echoRequest tcpip.MultiCounterStat echoReply tcpip.MultiCounterStat dstUnreachable tcpip.MultiCounterStat srcQuench tcpip.MultiCounterStat @@ -66,7 +66,7 @@ type multiCounterICMPv4PacketStats struct { } func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) { - m.echo.Init(a.Echo, b.Echo) + m.echoRequest.Init(a.EchoRequest, b.EchoRequest) m.echoReply.Init(a.EchoReply, b.EchoReply) m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable) m.srcQuench.Init(a.SrcQuench, b.SrcQuench) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 5f44ab317..6344a3e09 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -18,7 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -165,7 +164,7 @@ func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool { // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + h, ok := pkt.Data().PullUp(header.IPv6MinimumSize) if !ok { return } @@ -184,10 +183,10 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data().TrimFront(header.IPv6MinimumSize) p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { return } @@ -200,7 +199,7 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip fragmentation header and find out the actual protocol // number. - pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data().TrimFront(header.IPv6FragmentHeaderSize) p = fragHdr.TransportProtocol() } @@ -268,7 +267,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if routerAlert == nil || routerAlert.Value != header.IPv6RouterAlertMLD { return false } - if pkt.Data.Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { + if pkt.Data().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { return false } if iph.HopLimit() != header.MLDHopLimit { @@ -285,7 +284,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set. See icmp/protocol.go:protocol.Parse for a full explanation. - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) if !ok { received.invalid.Increment() return @@ -296,11 +295,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. - // - // This copy is used as extra payload during the checksum calculation. - payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) - if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { + payload := pkt.Data().AsRange().SubRange(len(h)) + if got, want := h.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: h, + Src: srcAddr, + Dst: dstAddr, + PayloadCsum: payload.Checksum(), + PayloadLen: payload.Size(), + }); got != want { received.invalid.Increment() return } @@ -320,12 +322,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.packetTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize) if !ok { received.invalid.Increment() return } - pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize) networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 @@ -334,12 +336,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize) if !ok { received.invalid.Increment() return } - pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) @@ -348,16 +350,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } case header.ICMPv6NeighborSolicit: received.neighborSolicit.Increment() - if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { + if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborSolicitMinimumSize { received.invalid.Increment() return } // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // payload.AsView() always returns the solicitation. Per RFC 6980 section 5, // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + // datagrams are very small and AsView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.AsView()) targetAddr := ns.TargetAddress() // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast @@ -380,6 +382,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // stack know so it can handle such a scenario and do nothing further with // the NS. if srcAddr == header.IPv6Any { + // Since this is a DAD message we know the sender does not actually hold + // the target address so there is no "holder". + var holderLinkAddress tcpip.LinkAddress + // We would get an error if the address no longer exists or the address // is no longer tentative (DAD resolved between the call to // hasTentativeAddr and this point). Both of these are valid scenarios: @@ -391,7 +397,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) @@ -529,7 +535,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) na.Options().Serialize(optsSerializer) - packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: packet, + Src: r.LocalAddress, + Dst: r.RemoteAddress, + })) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) // @@ -545,20 +555,34 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6NeighborAdvert: received.neighborAdvert.Increment() - if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize { + if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborAdvertMinimumSize { received.invalid.Increment() return } // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section + // payload.AsView() always returns the advertisement. Per RFC 6980 section // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + // NDP datagrams are very small and AsView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.AsView()) + + it, err := na.Options().Iter(false /* check */) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. + received.invalid.Increment() + return + } + + targetLinkAddr, ok := getTargetLinkAddr(it) + if !ok { + received.invalid.Increment() + return + } + targetAddr := na.TargetAddress() e.dad.mu.Lock() - e.dad.mu.dad.StopLocked(targetAddr, false /* aborted */) + e.dad.mu.dad.StopLocked(targetAddr, &stack.DADDupAddrDetected{HolderLinkAddress: targetLinkAddr}) e.dad.mu.Unlock() if e.hasTentativeAddr(targetAddr) { @@ -578,7 +602,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: return default: @@ -586,13 +610,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } } - it, err := na.Options().Iter(false /* check */) - if err != nil { - // If we have a malformed NDP NA option, drop the packet. - received.invalid.Increment() - return - } - // At this point we know that the target address is not tentative on the // NIC. However, the target address may still be assigned to the NIC but not // tentative (it could be permanent). Such a scenario is beyond the scope of @@ -602,11 +619,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // TODO(b/143147598): Handle the scenario described above. Also inform the // netstack integration that a duplicate address was detected outside of // DAD. - targetLinkAddr, ok := getTargetLinkAddr(it) - if !ok { - received.invalid.Increment() - return - } // As per RFC 4861 section 7.1.2: // A node MUST silently discard any received Neighbor Advertisement @@ -657,13 +669,20 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, - Data: pkt.Data, + Data: pkt.Data().ExtractVV(), }) - packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) + icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - copy(packet, icmpHdr) - packet.SetType(header.ICMPv6EchoReply) - packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + copy(icmp, icmpHdr) + icmp.SetType(header.ICMPv6EchoReply) + dataRange := replyPkt.Data().AsRange() + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: r.LocalAddress, + Dst: r.RemoteAddress, + PayloadCsum: dataRange.Checksum(), + PayloadLen: dataRange.Size(), + })) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), @@ -676,7 +695,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoReply: received.echoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if pkt.Data().Size() < header.ICMPv6EchoMinimumSize { received.invalid.Increment() return } @@ -696,7 +715,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Solictation? - if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { + if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { received.invalid.Increment() return } @@ -710,9 +729,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and ToView() + // Note that in the common case NDP datagrams are very small and AsView() // will not incur allocations. - rs := header.NDPRouterSolicit(payload.ToView()) + rs := header.NDPRouterSolicit(payload.AsView()) it, err := rs.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -756,7 +775,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Advertisement? - if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { + if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { received.invalid.Increment() return } @@ -770,9 +789,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and ToView() + // Note that in the common case NDP datagrams are very small and AsView() // will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(payload.AsView()) it, err := ra.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -850,11 +869,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType { case header.ICMPv6MulticastListenerQuery: e.mu.Lock() - e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.AsView())) e.mu.Unlock() case header.ICMPv6MulticastListenerReport: e.mu.Lock() - e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.AsView())) e.mu.Unlock() case header.ICMPv6MulticastListenerDone: default: @@ -1077,13 +1096,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip if available < header.IPv6MinimumSize { return nil } - payloadLen := network.Size() + transport.Size() + pkt.Data.Size() + payloadLen := network.Size() + transport.Size() + pkt.Data().Size() if payloadLen > available { payloadLen = available } payload := network.ToVectorisedView() payload.AppendView(transport) - payload.Append(pkt.Data) + payload.Append(pkt.Data().ExtractVV()) payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1115,7 +1134,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data)) + dataRange := newPkt.Data().AsRange() + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: route.LocalAddress, + Dst: route.RemoteAddress, + PayloadCsum: dataRange.Checksum(), + PayloadLen: dataRange.Size(), + })) if err := route.WritePacket( nil, /* gso */ stack.NetworkHeaderParams{ diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index c27164344..d4e63710c 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -324,7 +324,13 @@ func TestICMPCounts(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp[:typ.size], + Src: lladdr0, + Dst: lladdr1, + PayloadCsum: header.Checksum(typ.extraData, 0 /* initial */), + PayloadLen: len(typ.extraData), + })) handleICMPInIPv6(ep, lladdr1, lladdr0, icmp, typ.hopLimit, typ.includeRouterAlert) } @@ -498,7 +504,11 @@ func TestLinkResolution(t *testing.T) { hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: r.LocalAddress, + Dst: r.RemoteAddress, + })) // We can't send our payload directly over the route because that // doesn't provoke NDP discovery. @@ -687,7 +697,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: lladdr1, + Dst: lladdr0, + })) } ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -879,7 +893,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { payloadFn(icmpHdr.Payload()) if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: lladdr1, + Dst: lladdr0, + })) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1058,7 +1076,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { payloadFn(payload) if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: lladdr1, + Dst: lladdr0, + PayloadCsum: header.Checksum(payload, 0 /* initial */), + PayloadLen: len(payload), + })) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1324,7 +1348,11 @@ func TestPacketQueing(t *testing.T) { pkt.SetType(header.ICMPv6EchoRequest) pkt.SetCode(0) pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: host2IPv6Addr.AddressWithPrefix.Address, + Dst: host1IPv6Addr.AddressWithPrefix.Address, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, @@ -1422,7 +1450,11 @@ func TestPacketQueing(t *testing.T) { na.Options().Serialize(header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: host2IPv6Addr.AddressWithPrefix.Address, + Dst: host1IPv6Addr.AddressWithPrefix.Address, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -1661,7 +1693,11 @@ func TestCallsToNeighborCache(t *testing.T) { } icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: test.source, + Dst: test.destination, + })) handleICMPInIPv6(ep, test.source, test.destination, icmp, header.NDPHopLimit, false) // Confirm the endpoint calls the correct NUDHandler method. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7638ade35..46b6cc41a 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { // dupTentativeAddrDetected removes the tentative address if it exists. If the // address was generated via SLAAC, an attempt is made to generate a new // address. -func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error { +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -363,7 +363,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error { // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an // attempt will be made to generate a new address for it. - if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, true /* dadFailure */); err != nil { + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil { return err } @@ -536,8 +536,20 @@ func (e *endpoint) disableLocked() { } e.mu.ndp.stopSolicitingRouters() + // Stop DAD for all the tentative unicast addresses. + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.GetKind() != stack.PermanentTentative { + return true + } + + addr := addressEndpoint.AddressWithPrefix().Address + if header.IsV6UnicastAddress(addr) { + e.mu.ndp.stopDuplicateAddressDetection(addr, &stack.DADAborted{}) + } + + return true + }) e.mu.ndp.cleanupState(false /* hostOnly */) - e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { @@ -555,25 +567,6 @@ func (e *endpoint) disableLocked() { } } -// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. -// -// Precondition: e.mu must be write locked. -func (e *endpoint) stopDADForPermanentAddressesLocked() { - // Stop DAD for all the tentative unicast addresses. - e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { - if addressEndpoint.GetKind() != stack.PermanentTentative { - return true - } - - addr := addressEndpoint.AddressWithPrefix().Address - if header.IsV6UnicastAddress(addr) { - e.mu.ndp.stopDuplicateAddressDetection(addr, false /* failed */) - } - - return true - }) -} - // DefaultTTL is the default hop limit for this endpoint. func (e *endpoint) DefaultTTL() uint8 { return e.protocol.DefaultTTL() @@ -619,7 +612,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params } func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { - payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + payload := pkt.TransportHeader().View().Size() + pkt.Data().Size() return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } @@ -819,14 +812,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required checks. - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + h, ok := pkt.Data().PullUp(header.IPv6MinimumSize) if !ok { return &tcpip.ErrMalformedHeader{} } ip := header.IPv6(h) // Always set the payload length. - pktSize := pkt.Data.Size() + pktSize := pkt.Data().Size() ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) // Set the source address when zero. @@ -964,7 +957,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { stats := e.stats.ip h := header.IPv6(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { + if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { stats.MalformedPacketsReceived.Increment() return } @@ -993,13 +986,14 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } + // Create a VV to parse the packet. We don't plan to modify anything here. // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() vv.AppendView(pkt.TransportHeader().View()) - vv.Append(pkt.Data) + vv.AppendViews(pkt.Data().Views()) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) // iptables filtering. All packets that reach here are intended for @@ -1257,7 +1251,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // have more extension headers in the reassembled payload, as per RFC // 8200 section 4.5. We also use the NextHeader value from the first // fragment. - it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data) + data := pkt.Data() + dataVV := buffer.NewVectorisedView(data.Size(), data.Views()) + it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), dataVV) } case header.IPv6DestinationOptionsExtHdr: @@ -1314,7 +1310,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) - pkt.Data = extHdr.Buf + pkt.Data().Replace(extHdr.Buf) stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -1381,8 +1377,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { func (e *endpoint) Close() { e.mu.Lock() e.disableLocked() - e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */) - e.stopDADForPermanentAddressesLocked() e.mu.addressableEndpointState.Cleanup() e.mu.Unlock() @@ -1448,14 +1442,14 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } - return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, false /* dadFailure */) + return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, &stack.DADAborted{}) } // removePermanentEndpointLocked is like removePermanentAddressLocked except // it works with a stack.AddressEndpoint. // // Precondition: e.mu must be write locked. -func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation, dadFailure bool) tcpip.Error { +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool, dadResult stack.DADResult) tcpip.Error { addr := addressEndpoint.AddressWithPrefix() // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. @@ -1466,16 +1460,16 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr) } - return e.removePermanentEndpointInnerLocked(addressEndpoint, dadFailure) + return e.removePermanentEndpointInnerLocked(addressEndpoint, dadResult) } // removePermanentEndpointInnerLocked is like removePermanentEndpointLocked // except it does not cleanup SLAAC address state. // // Precondition: e.mu must be write locked. -func (e *endpoint) removePermanentEndpointInnerLocked(addressEndpoint stack.AddressEndpoint, dadFailure bool) tcpip.Error { +func (e *endpoint) removePermanentEndpointInnerLocked(addressEndpoint stack.AddressEndpoint, dadResult stack.DADResult) tcpip.Error { addr := addressEndpoint.AddressWithPrefix() - e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadFailure) + e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadResult) if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil { return err diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 05c9f4dbf..266a53e3b 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -68,7 +68,11 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: src, + Dst: dst, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -216,7 +220,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB // Store the reassembled payload as we parse each fragment. The payload // includes the Transport header and everything after. reassembledPayload.AppendView(fragment.TransportHeader().View()) - reassembledPayload.Append(fragment.Data) + reassembledPayload.AppendView(fragment.Data().AsRange().ToOwnedView()) } if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { @@ -3065,7 +3069,11 @@ func TestForwarding(t *testing.T) { icmp.SetType(header.ICMPv6EchoRequest) icmp.SetCode(header.ICMPv6UnusedCode) icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: remoteIPv6Addr1, + Dst: remoteIPv6Addr2, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 205e36cdd..dd153466d 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -236,7 +236,11 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp localAddress = header.IPv6Any } - icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: localAddress, + Dst: destAddress, + })) extensionHeaders := header.IPv6ExtHdrSerializer{ header.IPv6SerializableHopByHopExtHdr{ diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index f1b8d58f2..9a425e50a 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -326,11 +326,15 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, ho mld := header.MLD(icmp.MessageBody()) mld.SetMaximumResponseDelay(0) mld.SetMulticastAddress(header.IPv6Any) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddress, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: srcAddress, + Dst: header.IPv6AllNodesMulticastAddress, + })) - e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } func TestMLDPacketValidation(t *testing.T) { diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 721269c58..d9b728878 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -208,16 +208,12 @@ const ( // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { - // OnDuplicateAddressDetectionStatus is called when the DAD process for an - // address (addr) on a NIC (with ID nicID) completes. resolved is set to true - // if DAD completed successfully (no duplicate addr detected); false otherwise - // (addr was detected to be a duplicate on the link the NIC is a part of, or - // it was stopped for some other reason, such as the address being removed). - // If an error occured during DAD, err is set and resolved must be ignored. + // OnDuplicateAddressDetectionResult is called when the DAD process for an + // address on a NIC completes. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) + OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) // OnDefaultRouterDiscovered is called when a new default router is // discovered. Implementations must return true if the newly discovered @@ -225,14 +221,14 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool + OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool // OnDefaultRouterInvalidated is called when a discovered default router that // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) + OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. // Implementations must return true if the newly discovered on-link prefix @@ -240,14 +236,14 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool + OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) + OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) // OnAutoGenAddress is called when a new prefix with its autonomous address- // configuration flag set is received and SLAAC was performed. Implementations @@ -280,12 +276,12 @@ type NDPDispatcher interface { // It is up to the caller to use the DNS Servers only for their valid // lifetime. OnRecursiveDNSServerOption may be called for new or // already known DNS servers. If called with known DNS servers, their - // valid lifetimes must be refreshed to lifetime (it may be increased, - // decreased, or completely invalidated when lifetime = 0). + // valid lifetimes must be refreshed to the lifetime (it may be increased, + // decreased, or completely invalidated when the lifetime = 0). // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. - OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) // OnDNSSearchListOption is called when the stack learns of DNS search lists // through NDP. @@ -293,9 +289,9 @@ type NDPDispatcher interface { // It is up to the caller to use the domain names in the search list // for only their valid lifetime. OnDNSSearchListOption may be called // with new or already known domain names. If called with known domain - // names, their valid lifetimes must be refreshed to lifetime (it may - // be increased, decreased or completely invalidated when lifetime = 0. - OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) + // names, their valid lifetimes must be refreshed to the lifetime (it may + // be increased, decreased or completely invalidated when the lifetime = 0. + OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) // OnDHCPv6Configuration is called with an updated configuration that is // available via DHCPv6 for the passed NIC. @@ -587,15 +583,25 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } - if r.Resolved { + var dadSucceeded bool + switch r.(type) { + case *stack.DADAborted, *stack.DADError, *stack.DADDupAddrDetected: + dadSucceeded = false + case *stack.DADSucceeded: + dadSucceeded = true + default: + panic(fmt.Sprintf("unrecognized DAD result = %T", r)) + } + + if dadSucceeded { addressEndpoint.SetKind(stack.Permanent) } if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, r.Resolved, r.Err) + ndpDisp.OnDuplicateAddressDetectionResult(ndp.ep.nic.ID(), addr, r) } - if r.Resolved { + if dadSucceeded { if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { // Reset the generation attempts counter as we are starting the // generation of a new address for the SLAAC prefix. @@ -616,7 +622,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // Consider DAD to have resolved even if no DAD messages were actually // transmitted. if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) + ndpDisp.OnDuplicateAddressDetectionResult(ndp.ep.nic.ID(), addr, &stack.DADSucceeded{}) } ndp.ep.onAddressAssignedLocked(addr) @@ -633,8 +639,8 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // of this function to handle such a scenario. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, failed bool) { - ndp.dad.StopLocked(addr, !failed) +func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, reason stack.DADResult) { + ndp.dad.StopLocked(addr, reason) } // handleRA handles a Router Advertisement message that arrived on the NIC @@ -1501,7 +1507,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } - if err := ndp.ep.removePermanentEndpointInnerLocked(addressEndpoint, false /* dadFailure */); err != nil { + if err := ndp.ep.removePermanentEndpointInnerLocked(addressEndpoint, &stack.DADAborted{}); err != nil { panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err)) } } @@ -1560,7 +1566,7 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { ndp.cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs, tempAddr, tempAddrState) - if err := ndp.ep.removePermanentEndpointInnerLocked(tempAddrState.addressEndpoint, false /* dadFailure */); err != nil { + if err := ndp.ep.removePermanentEndpointInnerLocked(tempAddrState.addressEndpoint, &stack.DADAborted{}); err != nil { panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err)) } } @@ -1721,7 +1727,11 @@ func (ndp *ndpState) startSolicitingRouters() { icmpData.SetType(header.ICMPv6RouterSolicit) rs := header.NDPRouterSolicit(icmpData.MessageBody()) rs.Options().Serialize(optsSerializer) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpData, + Src: localAddr, + Dst: header.IPv6AllRoutersMulticastAddress, + })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), @@ -1812,7 +1822,11 @@ func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteL ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(opts) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddr, dstAddr, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: srcAddr, + Dst: dstAddr, + })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.MaxHeaderLength()), diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index ce20af0e3..6e850fd46 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -90,7 +90,7 @@ type testNDPDispatcher struct { addr tcpip.Address } -func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { +func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { @@ -215,7 +215,11 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: lladdr1, + Dst: lladdr0, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -478,7 +482,11 @@ func TestNeighborSolicitationResponse(t *testing.T) { ns.SetTargetAddress(nicAddr) opts := ns.Options() opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: test.nsSrc, + Dst: test.nsDst, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -554,7 +562,11 @@ func TestNeighborSolicitationResponse(t *testing.T) { na.SetOverrideFlag(true) na.SetTargetAddress(test.nsSrc) na.Options().Serialize(ser) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: test.nsSrc, + Dst: nicAddr, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -657,7 +669,11 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: lladdr1, + Dst: lladdr0, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -874,7 +890,13 @@ func TestNDPValidation(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp[:typ.size], + Src: lladdr0, + Dst: lladdr1, + PayloadCsum: header.Checksum(typ.extraData /* initial */, 0), + PayloadLen: len(typ.extraData), + })) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -987,7 +1009,11 @@ func TestNeighborAdvertisementValidation(t *testing.T) { na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetTargetAddress(lladdr1) na.SetSolicitedFlag(test.solicitedFlag) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: lladdr1, + Dst: test.ipDstAddr, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -1182,7 +1208,11 @@ func TestRouterAdvertValidation(t *testing.T) { pkt.SetCode(test.code) copy(pkt.MessageBody(), test.ndpPayload) payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: test.src, + Dst: header.IPv6AllNodesMulticastAddress, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), @@ -1284,10 +1314,10 @@ func TestCheckDuplicateAddress(t *testing.T) { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning) } - // Wait for DAD to resolve. + // Wait for DAD to complete. clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) for i := 0; i < dadRequestsMade; i++ { - if diff := cmp.Diff(stack.DADResult{Resolved: true}, <-ch); diff != "" { + if diff := cmp.Diff(&stack.DADSucceeded{}, <-ch); diff != "" { t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) } } diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 73913aef8..ecd5003a7 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -230,9 +230,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime b igmp.SetGroupAddress(groupAddress) igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } // createAndInjectMLDPacket creates and injects an MLD packet with the @@ -263,11 +263,15 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay b mld := header.MLD(icmp.MessageBody()) mld.SetMaximumResponseDelay(uint16(maxRespDelay)) mld.SetMulticastAddress(groupAddress) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, linkLocalIPv6Addr2, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: linkLocalIPv6Addr2, + Dst: header.IPv6AllNodesMulticastAddress, + })) - e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } // TestMGPDisabled tests that the multicast group protocol is not enabled by diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 57abec5c9..210262703 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "ports", - srcs = ["ports.go"], + srcs = [ + "flags.go", + "ports.go", + ], visibility = ["//visibility:public"], deps = [ "//pkg/sync", diff --git a/pkg/tcpip/ports/flags.go b/pkg/tcpip/ports/flags.go new file mode 100644 index 000000000..a8d7bff25 --- /dev/null +++ b/pkg/tcpip/ports/flags.go @@ -0,0 +1,150 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ports + +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precidence over MostRecent. + LoadBalanced bool + + // TupleOnly represents TCP SO_REUSEADDR. + TupleOnly bool +} + +// Bits converts the Flags to their bitset form. +func (f Flags) Bits() BitFlags { + var rf BitFlags + if f.MostRecent { + rf |= MostRecentFlag + } + if f.LoadBalanced { + rf |= LoadBalancedFlag + } + if f.TupleOnly { + rf |= TupleOnlyFlag + } + return rf +} + +// Effective returns the effective behavior of a flag config. +func (f Flags) Effective() Flags { + e := f + if e.LoadBalanced && e.MostRecent { + e.MostRecent = false + } + return e +} + +// BitFlags is a bitset representation of Flags. +type BitFlags uint32 + +const ( + // MostRecentFlag represents Flags.MostRecent. + MostRecentFlag BitFlags = 1 << iota + + // LoadBalancedFlag represents Flags.LoadBalanced. + LoadBalancedFlag + + // TupleOnlyFlag represents Flags.TupleOnly. + TupleOnlyFlag + + // nextFlag is the value that the next added flag will have. + // + // It is used to calculate FlagMask below. It is also the number of + // valid flag states. + nextFlag + + // FlagMask is a bit mask for BitFlags. + FlagMask = nextFlag - 1 + + // MultiBindFlagMask contains the flags that allow binding the same + // tuple multiple times. + MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag +) + +// ToFlags converts the bitset into a Flags struct. +func (f BitFlags) ToFlags() Flags { + return Flags{ + MostRecent: f&MostRecentFlag != 0, + LoadBalanced: f&LoadBalancedFlag != 0, + TupleOnly: f&TupleOnlyFlag != 0, + } +} + +// FlagCounter counts how many references each flag combination has. +type FlagCounter struct { + // refs stores the count for each possible flag combination, (0 though + // FlagMask). + refs [nextFlag]int +} + +// AddRef increases the reference count for a specific flag combination. +func (c *FlagCounter) AddRef(flags BitFlags) { + c.refs[flags]++ +} + +// DropRef decreases the reference count for a specific flag combination. +func (c *FlagCounter) DropRef(flags BitFlags) { + c.refs[flags]-- +} + +// TotalRefs calculates the total number of references for all flag +// combinations. +func (c FlagCounter) TotalRefs() int { + var total int + for _, r := range c.refs { + total += r + } + return total +} + +// FlagRefs returns the number of references with all specified flags. +func (c FlagCounter) FlagRefs(flags BitFlags) int { + var total int + for i, r := range c.refs { + if BitFlags(i)&flags == flags { + total += r + } + } + return total +} + +// AllRefsHave returns if all references have all specified flags. +func (c FlagCounter) AllRefsHave(flags BitFlags) bool { + for i, r := range c.refs { + if BitFlags(i)&flags != flags && r > 0 { + return false + } + } + return true +} + +// SharedFlags returns the set of flags shared by all references. +func (c FlagCounter) SharedFlags() BitFlags { + intersection := FlagMask + for i, r := range c.refs { + if r > 0 { + intersection &= BitFlags(i) + } + } + return intersection +} diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 11dbdbbcf..678199371 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ports provides PortManager that manages allocating, reserving and releasing ports. +// Package ports provides PortManager that manages allocating, reserving and +// releasing ports. package ports import ( - "math" "math/rand" "sync/atomic" @@ -24,169 +24,44 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -const ( - // FirstEphemeral is the first ephemeral port. - FirstEphemeral = 16000 +const anyIPAddress tcpip.Address = "" - // numEphemeralPorts it the mnumber of available ephemeral ports to - // Netstack. - numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1 +// Reservation describes a port reservation. +type Reservation struct { + // Networks is a list of network protocols to which the reservation + // applies. Can be IPv4, IPv6, or both. + Networks []tcpip.NetworkProtocolNumber - anyIPAddress tcpip.Address = "" -) - -type portDescriptor struct { - network tcpip.NetworkProtocolNumber - transport tcpip.TransportProtocolNumber - port uint16 -} - -// Flags represents the type of port reservation. -// -// +stateify savable -type Flags struct { - // MostRecent represents UDP SO_REUSEADDR. - MostRecent bool - - // LoadBalanced indicates SO_REUSEPORT. - // - // LoadBalanced takes precidence over MostRecent. - LoadBalanced bool - - // TupleOnly represents TCP SO_REUSEADDR. - TupleOnly bool -} - -// Bits converts the Flags to their bitset form. -func (f Flags) Bits() BitFlags { - var rf BitFlags - if f.MostRecent { - rf |= MostRecentFlag - } - if f.LoadBalanced { - rf |= LoadBalancedFlag - } - if f.TupleOnly { - rf |= TupleOnlyFlag - } - return rf -} - -// Effective returns the effective behavior of a flag config. -func (f Flags) Effective() Flags { - e := f - if e.LoadBalanced && e.MostRecent { - e.MostRecent = false - } - return e -} - -// PortManager manages allocating, reserving and releasing ports. -type PortManager struct { - mu sync.RWMutex - allocatedPorts map[portDescriptor]bindAddresses - - // hint is used to pick ports ephemeral ports in a stable order for - // a given port offset. - // - // hint must be accessed using the portHint/incPortHint helpers. - // TODO(gvisor.dev/issue/940): S/R this field. - hint uint32 -} - -// BitFlags is a bitset representation of Flags. -type BitFlags uint32 - -const ( - // MostRecentFlag represents Flags.MostRecent. - MostRecentFlag BitFlags = 1 << iota - - // LoadBalancedFlag represents Flags.LoadBalanced. - LoadBalancedFlag - - // TupleOnlyFlag represents Flags.TupleOnly. - TupleOnlyFlag - - // nextFlag is the value that the next added flag will have. - // - // It is used to calculate FlagMask below. It is also the number of - // valid flag states. - nextFlag - - // FlagMask is a bit mask for BitFlags. - FlagMask = nextFlag - 1 + // Transport is the transport protocol to which the reservation applies. + Transport tcpip.TransportProtocolNumber - // MultiBindFlagMask contains the flags that allow binding the same - // tuple multiple times. - MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag -) - -// ToFlags converts the bitset into a Flags struct. -func (f BitFlags) ToFlags() Flags { - return Flags{ - MostRecent: f&MostRecentFlag != 0, - LoadBalanced: f&LoadBalancedFlag != 0, - TupleOnly: f&TupleOnlyFlag != 0, - } -} + // Addr is the address of the local endpoint. + Addr tcpip.Address -// FlagCounter counts how many references each flag combination has. -type FlagCounter struct { - // refs stores the count for each possible flag combination, (0 though - // FlagMask). - refs [nextFlag]int -} + // Port is the local port number. + Port uint16 -// AddRef increases the reference count for a specific flag combination. -func (c *FlagCounter) AddRef(flags BitFlags) { - c.refs[flags]++ -} + // Flags describe features of the reservation. + Flags Flags -// DropRef decreases the reference count for a specific flag combination. -func (c *FlagCounter) DropRef(flags BitFlags) { - c.refs[flags]-- -} + // BindToDevice is the NIC to which the reservation applies. + BindToDevice tcpip.NICID -// TotalRefs calculates the total number of references for all flag -// combinations. -func (c FlagCounter) TotalRefs() int { - var total int - for _, r := range c.refs { - total += r - } - return total + // Dest is the destination address. + Dest tcpip.FullAddress } -// FlagRefs returns the number of references with all specified flags. -func (c FlagCounter) FlagRefs(flags BitFlags) int { - var total int - for i, r := range c.refs { - if BitFlags(i)&flags == flags { - total += r - } - } - return total -} - -// AllRefsHave returns if all references have all specified flags. -func (c FlagCounter) AllRefsHave(flags BitFlags) bool { - for i, r := range c.refs { - if BitFlags(i)&flags != flags && r > 0 { - return false - } +func (rs Reservation) dst() destination { + return destination{ + rs.Dest.Addr, + rs.Dest.Port, } - return true } -// IntersectionRefs returns the set of flags shared by all references. -func (c FlagCounter) IntersectionRefs() BitFlags { - intersection := FlagMask - for i, r := range c.refs { - if r > 0 { - intersection &= BitFlags(i) - } - } - return intersection +type portDescriptor struct { + network tcpip.NetworkProtocolNumber + transport tcpip.TransportProtocolNumber + port uint16 } type destination struct { @@ -194,18 +69,14 @@ type destination struct { port uint16 } -func makeDestination(a tcpip.FullAddress) destination { - return destination{ - a.Addr, - a.Port, - } -} - -// portNode is never empty. When it has no elements, it is removed from the -// map that references it. -type portNode map[destination]FlagCounter +// destToCounter maps each destination to the FlagCounter that represents +// endpoints to that destination. +// +// destToCounter is never empty. When it has no elements, it is removed from +// the map that references it. +type destToCounter map[destination]FlagCounter -// intersectionRefs calculates the intersection of flag bit values which affect +// intersectionFlags calculates the intersection of flag bit values which affect // the specified destination. // // If no destinations are present, all flag values are returned as there are no @@ -213,20 +84,20 @@ type portNode map[destination]FlagCounter // // In addition to the intersection, the number of intersecting refs is // returned. -func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { +func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) { intersection := FlagMask var count int - for d, f := range p { - if d == dst { - intersection &= f.IntersectionRefs() + for dest, counter := range dc { + if dest == res.dst() { + intersection &= counter.SharedFlags() count++ continue } // Wildcard destinations affect all destinations for TupleOnly. - if d.addr == anyIPAddress || dst.addr == anyIPAddress { + if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress { // Only bitwise and the TupleOnlyFlag. - intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs()) + intersection &= ((^TupleOnlyFlag) | counter.SharedFlags()) count++ } } @@ -234,27 +105,29 @@ func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { return intersection, count } -// deviceNode is never empty. When it has no elements, it is removed from the +// deviceToDest maps NICs to destinations for which there are port reservations. +// +// deviceToDest is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]portNode +type deviceToDest map[tcpip.NICID]destToCounter -// isAvailable checks whether binding is possible by device. If not binding to a -// device, check against all FlagCounters. If binding to a specific device, check -// against the unspecified device and the provided device. +// isAvailable checks whether binding is possible by device. If not binding to +// a device, check against all FlagCounters. If binding to a specific device, +// check against the unspecified device and the provided device. // // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - flagBits := flags.Bits() - if bindToDevice == 0 { +func (dd deviceToDest) isAvailable(res Reservation) bool { + flagBits := res.Flags.Bits() + if res.BindToDevice == 0 { intersection := FlagMask - for _, p := range d { - i, c := p.intersectionRefs(dst) - if c == 0 { + for _, dest := range dd { + flags, count := dest.intersectionFlags(res) + if count == 0 { continue } - intersection &= i + intersection &= flags if intersection&flagBits == 0 { // Can't bind because the (addr,port) was // previously bound without reuse. @@ -266,18 +139,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti intersection := FlagMask - if p, ok := d[0]; ok { - var c int - intersection, c = p.intersectionRefs(dst) - if c > 0 && intersection&flagBits == 0 { + if dests, ok := dd[0]; ok { + var count int + intersection, count = dests.intersectionFlags(res) + if count > 0 && intersection&flagBits == 0 { return false } } - if p, ok := d[bindToDevice]; ok { - i, c := p.intersectionRefs(dst) - intersection &= i - if c > 0 && intersection&flagBits == 0 { + if dests, ok := dd[res.BindToDevice]; ok { + flags, count := dests.intersectionFlags(res) + intersection &= flags + if count > 0 && intersection&flagBits == 0 { return false } } @@ -285,18 +158,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti return true } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]deviceNode +// addrToDevice maps IP addresses to NICs that have port reservations. +type addrToDevice map[tcpip.Address]deviceToDest // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - if addr == anyIPAddress { - // If binding to the "any" address then check that there are no conflicts - // with all addresses. - for _, d := range b { - if !d.isAvailable(flags, bindToDevice, dst) { +func (ad addrToDevice) isAvailable(res Reservation) bool { + if res.Addr == anyIPAddress { + // If binding to the "any" address then check that there are no + // conflicts with all addresses. + for _, devices := range ad { + if !devices.isAvailable(res) { return false } } @@ -304,15 +177,15 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice } // Check that there is no conflict with the "any" address. - if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(flags, bindToDevice, dst) { + if devices, ok := ad[anyIPAddress]; ok { + if !devices.isAvailable(res) { return false } } // Check that this is no conflict with the provided address. - if d, ok := b[addr]; ok { - if !d.isAvailable(flags, bindToDevice, dst) { + if devices, ok := ad[res.Addr]; ok { + if !devices.isAvailable(res) { return false } } @@ -320,50 +193,93 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice return true } +// PortManager manages allocating, reserving and releasing ports. +type PortManager struct { + // mu protects allocatedPorts. + // LOCK ORDERING: mu > ephemeralMu. + mu sync.RWMutex + // allocatedPorts is a nesting of maps that ultimately map Reservations + // to FlagCounters describing whether the Reservation is valid and can + // be reused. + allocatedPorts map[portDescriptor]addrToDevice + + // ephemeralMu protects firstEphemeral and numEphemeral. + ephemeralMu sync.RWMutex + firstEphemeral uint16 + numEphemeral uint16 + + // hint is used to pick ports ephemeral ports in a stable order for + // a given port offset. + // + // hint must be accessed using the portHint/incPortHint helpers. + // TODO(gvisor.dev/issue/940): S/R this field. + hint uint32 +} + // NewPortManager creates new PortManager. func NewPortManager() *PortManager { - return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)} + return &PortManager{ + allocatedPorts: make(map[portDescriptor]addrToDevice), + // Match Linux's default ephemeral range. See: + // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842 + firstEphemeral: 32768, + numEphemeral: 28232, + } } +// PortTester indicates whether the passed in port is suitable. Returning an +// error causes the function to which the PortTester is passed to return that +// error. +type PortTester func(port uint16) (good bool, err tcpip.Error) + // PickEphemeralPort randomly chooses a starting point and iterates over all // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - offset := uint32(rand.Int31n(numEphemeralPorts)) - return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) +func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) { + pm.ephemeralMu.RLock() + firstEphemeral := pm.firstEphemeral + numEphemeral := pm.numEphemeral + pm.ephemeralMu.RUnlock() + + offset := uint16(rand.Int31n(int32(numEphemeral))) + return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } -// portHint atomically reads and returns the s.hint value. -func (s *PortManager) portHint() uint32 { - return atomic.LoadUint32(&s.hint) +// portHint atomically reads and returns the pm.hint value. +func (pm *PortManager) portHint() uint16 { + return uint16(atomic.LoadUint32(&pm.hint)) } -// incPortHint atomically increments s.hint by 1. -func (s *PortManager) incPortHint() { - atomic.AddUint32(&s.hint, 1) +// incPortHint atomically increments pm.hint by 1. +func (pm *PortManager) incPortHint() { + atomic.AddUint32(&pm.hint, 1) } -// PickEphemeralPortStable starts at the specified offset + s.portHint and +// PickEphemeralPortStable starts at the specified offset + pm.portHint and // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. -func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) +func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) { + pm.ephemeralMu.RLock() + firstEphemeral := pm.firstEphemeral + numEphemeral := pm.numEphemeral + pm.ephemeralMu.RUnlock() + + p, err := pickEphemeralPort(pm.portHint()+offset, firstEphemeral, numEphemeral, testPort) if err == nil { - s.incPortHint() + pm.incPortHint() } return p, err - } // pickEphemeralPort starts at the offset specified from the FirstEphemeral port // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - for i := uint32(0); i < count; i++ { - port = uint16(FirstEphemeral + (offset+i)%count) +func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { + for i := uint16(0); i < count; i++ { + port = first + (offset+i)%count ok, err := testPort(port) if err != nil { return 0, err @@ -377,144 +293,145 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui return 0, &tcpip.ErrNoPortAvailable{} } -// IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest)) -} - -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - for _, network := range networks { - desc := portDescriptor{network, transport, port} - if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, flags, bindToDevice, dst) { - return false - } - } - } - return true -} - // ReservePort marks a port/IP combination as reserved so that it cannot be // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. // -// An optional testPort closure can be passed in which if provided will be used -// to test if the picked port can be used. The function should return true if -// the port is safe to use, false otherwise. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) { - s.mu.Lock() - defer s.mu.Unlock() - - dst := makeDestination(dest) +// An optional PortTester can be passed in which if provided will be used to +// test if the picked port can be used. The function should return true if the +// port is safe to use, false otherwise. +func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { + pm.mu.Lock() + defer pm.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. - if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { + if res.Port != 0 { + if !pm.reserveSpecificPortLocked(res) { return 0, &tcpip.ErrPortInUse{} } - if testPort != nil && !testPort(port) { - s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst) - return 0, &tcpip.ErrPortInUse{} + if testPort != nil { + ok, err := testPort(res.Port) + if err != nil { + pm.releasePortLocked(res) + return 0, err + } + if !ok { + pm.releasePortLocked(res) + return 0, &tcpip.ErrPortInUse{} + } } - return port, nil + return res.Port, nil } // A port wasn't specified, so try to find one. - return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { - if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) { + return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + res.Port = p + if !pm.reserveSpecificPortLocked(res) { return false, nil } - if testPort != nil && !testPort(p) { - s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst) - return false, nil + if testPort != nil { + ok, err := testPort(p) + if err != nil { + pm.releasePortLocked(res) + return false, err + } + if !ok { + pm.releasePortLocked(res) + return false, nil + } } return true, nil }) } -// reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) { - return false +// reserveSpecificPortLocked tries to reserve the given port on all given +// protocols. +func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool { + // Make sure the port is available. + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + if addrs, ok := pm.allocatedPorts[desc]; ok { + if !addrs.isAvailable(res) { + return false + } + } } - flagBits := flags.Bits() - // Reserve port on all network protocols. - for _, network := range networks { - desc := portDescriptor{network, transport, port} - m, ok := s.allocatedPorts[desc] + flagBits := res.Flags.Bits() + dst := res.dst() + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] if !ok { - m = make(bindAddresses) - s.allocatedPorts[desc] = m + addrToDev = make(addrToDevice) + pm.allocatedPorts[desc] = addrToDev } - d, ok := m[addr] + devToDest, ok := addrToDev[res.Addr] if !ok { - d = make(deviceNode) - m[addr] = d + devToDest = make(deviceToDest) + addrToDev[res.Addr] = devToDest } - p := d[bindToDevice] - if p == nil { - p = make(portNode) + destToCntr := devToDest[res.BindToDevice] + if destToCntr == nil { + destToCntr = make(destToCounter) } - n := p[dst] - n.AddRef(flagBits) - p[dst] = n - d[bindToDevice] = p + counter := destToCntr[dst] + counter.AddRef(flagBits) + destToCntr[dst] = counter + devToDest[res.BindToDevice] = destToCntr } return true } // ReserveTuple adds a port reservation for the tuple on all given protocol. -func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { - flagBits := flags.Bits() - dst := makeDestination(dest) +func (pm *PortManager) ReserveTuple(res Reservation) bool { + flagBits := res.Flags.Bits() + dst := res.dst() - s.mu.Lock() - defer s.mu.Unlock() + pm.mu.Lock() + defer pm.mu.Unlock() // It is easier to undo the entire reservation, so if we find that the // tuple can't be fully added, finish and undo the whole thing. undo := false // Reserve port on all network protocols. - for _, network := range networks { - desc := portDescriptor{network, transport, port} - m, ok := s.allocatedPorts[desc] + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] if !ok { - m = make(bindAddresses) - s.allocatedPorts[desc] = m + addrToDev = make(addrToDevice) + pm.allocatedPorts[desc] = addrToDev } - d, ok := m[addr] + devToDest, ok := addrToDev[res.Addr] if !ok { - d = make(deviceNode) - m[addr] = d + devToDest = make(deviceToDest) + addrToDev[res.Addr] = devToDest } - p := d[bindToDevice] - if p == nil { - p = make(portNode) + destToCntr := devToDest[res.BindToDevice] + if destToCntr == nil { + destToCntr = make(destToCounter) } - n := p[dst] - if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 { + counter := destToCntr[dst] + if counter.TotalRefs() != 0 && counter.SharedFlags()&flagBits == 0 { // Tuple already exists. undo = true } - n.AddRef(flagBits) - p[dst] = n - d[bindToDevice] = p + counter.AddRef(flagBits) + destToCntr[dst] = counter + devToDest[res.BindToDevice] = destToCntr } if undo { // releasePortLocked decrements the counts (rather than setting // them to zero), so it will undo the incorrect incrementing // above. - s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst) + pm.releasePortLocked(res) return false } @@ -523,47 +440,71 @@ func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, trans // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) { - s.mu.Lock() - defer s.mu.Unlock() +func (pm *PortManager) ReleasePort(res Reservation) { + pm.mu.Lock() + defer pm.mu.Unlock() - s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest)) + pm.releasePortLocked(res) } -func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) { - for _, network := range networks { - desc := portDescriptor{network, transport, port} - if m, ok := s.allocatedPorts[desc]; ok { - d, ok := m[addr] - if !ok { - continue - } - p, ok := d[bindToDevice] - if !ok { - continue - } - n, ok := p[dst] - if !ok { - continue - } - n.DropRef(flags) - if n.TotalRefs() > 0 { - p[dst] = n - continue - } - delete(p, dst) - if len(p) > 0 { - continue - } - delete(d, bindToDevice) - if len(d) > 0 { - continue - } - delete(m, addr) - if len(m) > 0 { - continue - } - delete(s.allocatedPorts, desc) +func (pm *PortManager) releasePortLocked(res Reservation) { + dst := res.dst() + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] + if !ok { + continue } + devToDest, ok := addrToDev[res.Addr] + if !ok { + continue + } + destToCounter, ok := devToDest[res.BindToDevice] + if !ok { + continue + } + counter, ok := destToCounter[dst] + if !ok { + continue + } + counter.DropRef(res.Flags.Bits()) + if counter.TotalRefs() > 0 { + destToCounter[dst] = counter + continue + } + delete(destToCounter, dst) + if len(destToCounter) > 0 { + continue + } + delete(devToDest, res.BindToDevice) + if len(devToDest) > 0 { + continue + } + delete(addrToDev, res.Addr) + if len(addrToDev) > 0 { + continue + } + delete(pm.allocatedPorts, desc) + } +} + +// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in +// both IPv4 and IPv6. +func (pm *PortManager) PortRange() (uint16, uint16) { + pm.ephemeralMu.RLock() + defer pm.ephemeralMu.RUnlock() + return pm.firstEphemeral, pm.firstEphemeral + pm.numEphemeral - 1 +} + +// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range +// (inclusive). +func (pm *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error { + if start > end { + return &tcpip.ErrInvalidPortRange{} } + pm.ephemeralMu.Lock() + defer pm.ephemeralMu.Unlock() + pm.firstEphemeral = start + pm.numEphemeral = end - start + 1 + return nil } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index e70fbb72b..0f43dc8f8 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -329,16 +329,35 @@ func TestPortReservation(t *testing.T) { net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} for _, test := range test.actions { + first, _ := pm.PortRange() if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) + portRes := Reservation{ + Networks: net, + Transport: fakeTransNumber, + Addr: test.ip, + Port: test.port, + Flags: test.flags, + BindToDevice: test.device, + Dest: test.dest, + } + pm.ReleasePort(portRes) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) + portRes := Reservation{ + Networks: net, + Transport: fakeTransNumber, + Addr: test.ip, + Port: test.port, + Flags: test.flags, + BindToDevice: test.device, + Dest: test.dest, + } + gotPort, err := pm.ReservePort(portRes, nil /* testPort */) if diff := cmp.Diff(test.want, err); diff != "" { - t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff) + t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff) } - if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + if test.port == 0 && (gotPort == 0 || gotPort < first) { + t.Fatalf("ReservePort(%+v, _) = %d, want port number >= %d to be picked", portRes, gotPort, first) } } }) @@ -346,6 +365,11 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { + const ( + firstEphemeral = 32000 + numEphemeralPorts = 1000 + ) + for _, test := range []struct { name string f func(port uint16) (bool, tcpip.Error) @@ -369,17 +393,17 @@ func TestPickEphemeralPort(t *testing.T) { { name: "only-port-16042-available", f: func(port uint16) (bool, tcpip.Error) { - if port == FirstEphemeral+42 { + if port == firstEphemeral+42 { return true, nil } return false, nil }, - wantPort: FirstEphemeral + 42, + wantPort: firstEphemeral + 42, }, { name: "only-port-under-16000-available", f: func(port uint16) (bool, tcpip.Error) { - if port < FirstEphemeral { + if port < firstEphemeral { return true, nil } return false, nil @@ -389,6 +413,9 @@ func TestPickEphemeralPort(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() + if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { + t.Fatalf("failed to set ephemeral port range: %s", err) + } port, err := pm.PickEphemeralPort(test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) @@ -401,6 +428,11 @@ func TestPickEphemeralPort(t *testing.T) { } func TestPickEphemeralPortStable(t *testing.T) { + const ( + firstEphemeral = 32000 + numEphemeralPorts = 1000 + ) + for _, test := range []struct { name string f func(port uint16) (bool, tcpip.Error) @@ -424,17 +456,17 @@ func TestPickEphemeralPortStable(t *testing.T) { { name: "only-port-16042-available", f: func(port uint16) (bool, tcpip.Error) { - if port == FirstEphemeral+42 { + if port == firstEphemeral+42 { return true, nil } return false, nil }, - wantPort: FirstEphemeral + 42, + wantPort: firstEphemeral + 42, }, { name: "only-port-under-16000-available", f: func(port uint16) (bool, tcpip.Error) { - if port < FirstEphemeral { + if port < firstEphemeral { return true, nil } return false, nil @@ -444,7 +476,10 @@ func TestPickEphemeralPortStable(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() - portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) + if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { + t.Fatalf("failed to set ephemeral port range: %s", err) + } + portOffset := uint16(rand.Int31n(int32(numEphemeralPorts))) port, err := pm.PickEphemeralPortStable(portOffset, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index cdb435644..3f083928f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -407,12 +407,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data.Size()) + length := uint16(len(tcpHeader) + pkt.Data().Size()) xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d63e9757c..0e8b90c9b 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 740bdac28..47796a6ba 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -99,12 +99,11 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi } // ndpDADEvent is a set of parameters that was passed to -// ndpDispatcher.OnDuplicateAddressDetectionStatus. +// ndpDispatcher.OnDuplicateAddressDetectionResult. type ndpDADEvent struct { - nicID tcpip.NICID - addr tcpip.Address - resolved bool - err tcpip.Error + nicID tcpip.NICID + addr tcpip.Address + res stack.DADResult } type ndpRouterEvent struct { @@ -173,14 +172,13 @@ type ndpDispatcher struct { dhcpv6ConfigurationC chan ndpDHCPv6Event } -// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) { +// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionResult. +func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) { if n.dadC != nil { n.dadC <- ndpDADEvent{ nicID, addr, - resolved, - err, + res, } } } @@ -311,8 +309,8 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string { - return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) string { + return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, res: res}, e, cmp.AllowUnexported(e)) } // TestDADDisabled tests that an address successfully resolves immediately @@ -344,8 +342,8 @@ func TestDADDisabled(t *testing.T) { // DAD on it. select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -491,8 +489,8 @@ func TestDADResolve(t *testing.T) { case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { @@ -573,7 +571,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: header.IPv6Any, + Dst: snmc, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -594,9 +596,10 @@ func TestDADFail(t *testing.T) { const nicID = 1 tests := []struct { - name string - rxPkt func(e *channel.Endpoint, tgt tcpip.Address) - getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + name string + rxPkt func(e *channel.Endpoint, tgt tcpip.Address) + getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + expectedHolderLinkAddress tcpip.LinkAddress }{ { name: "RxSolicit", @@ -604,6 +607,7 @@ func TestDADFail(t *testing.T) { getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborSolicit }, + expectedHolderLinkAddress: "", }, { name: "RxAdvert", @@ -619,7 +623,11 @@ func TestDADFail(t *testing.T) { na.Options().Serialize(header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(linkAddr1), }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: tgt, + Dst: header.IPv6AllNodesMulticastAddress, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -634,6 +642,7 @@ func TestDADFail(t *testing.T) { getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborAdvert }, + expectedHolderLinkAddress: linkAddr1, }, } @@ -683,8 +692,8 @@ func TestDADFail(t *testing.T) { // something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { @@ -782,8 +791,8 @@ func TestDADStop(t *testing.T) { // time + extra 1s buffer, something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } @@ -844,8 +853,8 @@ func TestSetNDPConfigurations(t *testing.T) { expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatalf("expected DAD event for %s", addr) @@ -936,8 +945,8 @@ func TestSetNDPConfigurations(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { @@ -973,7 +982,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo } opts := ra.Options() opts.Serialize(optSer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: ip, + Dst: header.IPv6AllNodesMulticastAddress, + })) payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ @@ -1951,8 +1964,8 @@ func TestAutoGenTempAddr(t *testing.T) { select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2157,8 +2170,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2245,8 +2258,8 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // address to be generated. select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2711,8 +2724,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { t.Helper() clock.Advance(dupAddrTransmits * retransmitTimer) - if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } @@ -2742,8 +2755,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { rxNDPSolicit(e, addr.Address) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -3841,26 +3854,26 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) { + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, res); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") } } - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, res); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -3917,7 +3930,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // generated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true) + expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) // The stable address will be assigned throughout the test. return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} @@ -3992,7 +4005,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, false, nil) + expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4002,7 +4015,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{}) + expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4015,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if maxRetries+1 > numFailures { addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, true) + expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { t.Fatal(mismatch) } @@ -4132,8 +4145,8 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -4231,8 +4244,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -4243,8 +4256,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index afff1b434..48bb75e2f 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -59,21 +59,24 @@ const ( infiniteDuration = time.Duration(math.MaxInt64) ) -// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor -// entries. The UpdatedAtNanos field is ignored due to a lack of a -// deterministic method to predict the time that an event will be dispatched. -func entryDiffOpts() []cmp.Option { +// unorderedEventsDiffOpts returns options passed to cmp.Diff to sort slices of +// events for cases where ordering must be ignored. +func unorderedEventsDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), + cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 + }), } } -// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to -// sort slices of entries for cases where ordering must be ignored. -func entryDiffOptsWithSort() []cmp.Option { - return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - })) +// unorderedEntriesDiffOpts returns options passed to cmp.Diff to sort slices of +// entries for cases where ordering must be ignored. +func unorderedEntriesDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } } func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { @@ -280,48 +283,105 @@ func TestNeighborCacheSetConfig(t *testing.T) { } } -func TestNeighborCacheEntry(t *testing.T) { - c := DefaultNUDConfigurations() - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) +func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry, removed []NeighborEntry) error { + var gotLinkResolutionResult LinkResolutionResult - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { + gotLinkResolutionResult = r + }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + return fmt.Errorf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - clock.Advance(typicalLatency) + { + var wantEvents []testEntryEventInfo - wantEvents := []testEntryEventInfo{ - { + for _, removedEntry := range removed { + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: 1, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }) + } + + wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAtNanos: clock.NowNanoseconds(), }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + }) + + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + } + + clock.Advance(typicalLatency) + + select { + case <-ch: + default: + return fmt.Errorf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) + } + wantLinkResolutionResult := LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil} + if diff := cmp.Diff(wantLinkResolutionResult, gotLinkResolutionResult); diff != "" { + return fmt.Errorf("got link resolution result mismatch (-want +got):\n%s", diff) + } + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, }, - }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + + return nil +} + +func addReachableEntry(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry) error { + return addReachableEntryWithRemoved(nudDisp, clock, linkRes, entry, nil /* removed */) +} + +func TestNeighborCacheEntry(t *testing.T) { + c := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + linkRes := newTestNeighborResolver(&nudDisp, c, clock) + + entry, ok := linkRes.entries.entry(0) + if !ok { + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") + } + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { @@ -345,41 +405,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } - - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - clock.Advance(typicalLatency) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } linkRes.neigh.removeEntry(entry.Addr) @@ -390,14 +419,15 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), }, }, } nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.mu.events) nudDisp.mu.Unlock() if diff != "" { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) @@ -439,18 +469,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // Fill the neighbor cache to capacity to verify the LRU eviction strategy is // working properly after the entry removal. for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { - // Add a new entry - entry, ok := c.linkRes.entries.entry(i) - if !ok { - return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) - - var wantEvents []testEntryEventInfo + var removedEntries []NeighborEntry // When beyond the full capacity, the cache will evict an entry as per the // LRU eviction strategy. Note that the number of static entries should not @@ -458,63 +477,40 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if i >= neighborCacheSize+opts.startAtEntryIndex { removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) if !ok { - return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize) + return fmt.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize) } - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - }, - }) + removedEntries = append(removedEntries, removedEntry) } - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, testEntryEventInfo{ - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }) - - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + entry, ok := c.linkRes.entries.entry(i) + if !ok { + return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) + } + if err := addReachableEntryWithRemoved(c.nudDisp, c.clock, c.linkRes, entry, removedEntries); err != nil { + return fmt.Errorf("addReachableEntryWithRemoved(...) = %s", err) } } // Expect to find only the most recent entries. The order of entries reported // by entries() is nondeterministic, so entries have to be sorted before // comparison. - wantUnsortedEntries := opts.wantStaticEntries + wantUnorderedEntries := opts.wantStaticEntries for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) + return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) } + durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos, } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnorderedEntries, c.linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -560,38 +556,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Remove the entry @@ -603,14 +571,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) c.nudDisp.mu.events = nil c.nudDisp.mu.Unlock() if diff != "" { @@ -636,33 +605,36 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // Add a static entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } staticLinkAddr := entry.LinkAddr + "static" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), + }, }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } - // Remove the static entry that was just added + // Add a duplicate static entry with the same link address. c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" { @@ -680,48 +652,56 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a static entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } staticLinkAddr := entry.LinkAddr + "static" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), + }, }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } // Add a duplicate entry with a different link address staticLinkAddr += "duplicate" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { wantEvents := []testEntryEventInfo{ { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -742,45 +722,51 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { // Add a static entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } staticLinkAddr := entry.LinkAddr + "static" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), + }, }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } // Remove the static entry that was just added c.linkRes.neigh.removeEntry(entry.Addr) + { wantEvents := []testEntryEventInfo{ { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) c.nudDisp.mu.events = nil c.nudDisp.mu.Unlock() if diff != "" { @@ -812,66 +798,41 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(typicalLatency) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Override the entry with a static one using the same address staticLinkAddr := entry.LinkAddr + "static" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { wantEvents := []testEntryEventInfo{ { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) c.nudDisp.mu.events = nil c.nudDisp.mu.Unlock() if diff != "" { @@ -883,9 +844,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } @@ -905,7 +867,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) @@ -913,40 +875,45 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + if diff := cmp.Diff(want, e); diff != "" { t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), + }, }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } opts := overflowOptions{ startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } @@ -965,39 +932,10 @@ func TestNeighborCacheClear(t *testing.T) { // Add a dynamic entry. entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Add a static entry. @@ -1009,14 +947,15 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAtNanos: clock.NowNanoseconds(), }, }, } nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.mu.events) nudDisp.mu.events = nil nudDisp.mu.Unlock() if diff != "" { @@ -1028,30 +967,32 @@ func TestNeighborCacheClear(t *testing.T) { linkRes.neigh.clear() // Remove events dispatched from clear() have no deterministic order so they - // need to be sorted beforehand. - wantUnsortedEvents := []testEntryEventInfo{ + // need to be sorted before comparison. + wantUnorderedEvents := []testEntryEventInfo{ { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), }, }, { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAtNanos: clock.NowNanoseconds(), }, }, } nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(wantUnsortedEvents, nudDisp.mu.events, eventDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnorderedEvents, nudDisp.mu.events, unorderedEventsDiffOpts()...); diff != "" { t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1071,56 +1012,30 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(typicalLatency) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Clear the cache. c.linkRes.neigh.clear() + { wantEvents := []testEntryEventInfo{ { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) c.nudDisp.mu.events = nil c.nudDisp.mu.Unlock() if diff != "" { @@ -1147,10 +1062,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - frequentlyUsedEntry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } + startedAt := clock.NowNanoseconds() // The following logic is very similar to overflowCache, but // periodically refreshes the frequently used entry. @@ -1159,50 +1071,18 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { for i := 0; i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) + t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } } + frequentlyUsedEntry, ok := linkRes.entries.entry(0) + if !ok { + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") + } + // Keep adding more entries for i := neighborCacheSize; i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry @@ -1214,63 +1094,17 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } // An entry should have been removed, as per the LRU eviction strategy removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1) - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - }, - }, - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, + t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize+1) } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + + if err := addReachableEntryWithRemoved(&nudDisp, clock, linkRes, entry, []NeighborEntry{removedEntry}); err != nil { + t.Fatalf("addReachableEntryWithRemoved(...) = %s", err) } } @@ -1282,23 +1116,27 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { Addr: frequentlyUsedEntry.Addr, LinkAddr: frequentlyUsedEntry.LinkAddr, State: Reachable, + // Can be inferred since the frequently used entry is the first to + // be created and transitioned to Reachable. + UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(), }, } for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) - } - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + }) } - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -1350,17 +1188,18 @@ func TestNeighborCacheConcurrent(t *testing.T) { for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { entry, ok := linkRes.entries.entry(i) if !ok { - t.Errorf("linkRes.entries.entry(%d) not found", i) + t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + }) } - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1372,44 +1211,12 @@ func TestNeighborCacheReplace(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - // Add an entry entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } - - // Verify the entry exists - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) - } - if t.Failed() { - t.FailNow() - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Notify of a link address change @@ -1417,7 +1224,7 @@ func TestNeighborCacheReplace(t *testing.T) { { entry, ok := linkRes.entries.entry(1) if !ok { - t.Fatal("linkRes.entries.entry(1) not found") + t.Fatal("got linkRes.entries.entry(1) = _, false, want = true") } updatedLinkAddr = entry.LinkAddr } @@ -1437,29 +1244,31 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, + UpdatedAtNanos: clock.NowNanoseconds(), } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } - clock.Advance(config.DelayFirstProbeTime + typicalLatency) } + clock.Advance(config.DelayFirstProbeTime + typicalLatency) + // Verify that the neighbor is now reachable. { e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - clock.Advance(typicalLatency) if err != nil { t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } @@ -1479,25 +1288,12 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } // First, sanity check that resolution is working - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) @@ -1505,11 +1301,12 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), } - if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } @@ -1524,14 +1321,14 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) } } @@ -1555,7 +1352,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { @@ -1564,7 +1361,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1572,7 +1369,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) } } @@ -1580,14 +1377,15 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { // failing to perform address resolution. func TestNeighborCacheRetryResolution(t *testing.T) { config := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nil, config, clock) + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Simulate a faulty link. linkRes.dropReplies = true entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } // Perform address resolution with a faulty link, which will fail. @@ -1598,27 +1396,75 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + } + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) } - } - wantEntries := []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - }, - } - if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { - t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + } + + { + wantEntries := []NeighborEntry{ + { + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + } + if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { + t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + } } // Retry address resolution with a working link. @@ -1635,28 +1481,74 @@ func TestNeighborCacheRetryResolution(t *testing.T) { if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) } + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + } + clock.Advance(typicalLatency) select { case <-ch: - if !ok { - t.Fatal("expected successful address resolution") + default: + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) + } + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }, } - reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - if reachableEntry.Addr != entry.Addr { - t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) + } + + { + gotEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) + if err != nil { + t.Fatalf("linkRes.neigh.entry(%s, '', _): %s", entry.Addr, err) } - if reachableEntry.LinkAddr != entry.LinkAddr { - t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr) + + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), } - if reachableEntry.State != Reachable { - t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) + if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { + t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) } - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } } @@ -1674,7 +1566,7 @@ func BenchmarkCacheClear(b *testing.B) { for i := 0; i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { - b.Fatalf("linkRes.entries.entry(%d) not found", i) + b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { @@ -1683,13 +1575,13 @@ func BenchmarkCacheClear(b *testing.B) { } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + b.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { case <-ch: default: - b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) } } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index baae7dfe1..bb2b2d705 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -18,13 +18,11 @@ import ( "fmt" "math" "math/rand" - "strings" "sync" "testing" "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -52,23 +50,6 @@ func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { clock.Advance(immediateDuration) } -// eventDiffOpts are the options passed to cmp.Diff to compare entry events. -// The UpdatedAtNanos field is ignored due to a lack of a deterministic method -// to predict the time that an event will be dispatched. -func eventDiffOpts() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), - } -} - -// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to -// sort slices of events for cases where ordering must be ignored. -func eventDiffOptsWithSort() []cmp.Option { - return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 - })) -} - // The following unit tests exercise every state transition and verify its // behavior with RFC 4681 and RFC 7048. // diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f9323d545..62f7c880e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -725,12 +725,12 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.mu.RUnlock() n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) return } n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { @@ -881,7 +881,7 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) + transHeader, ok := pkt.Data().PullUp(8) if !ok { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 4f013b212..8f288675d 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -59,7 +59,7 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. + // data holds the payload of the packet. // // For inbound packets, Data is initially the whole packet. Then gets moved to // headers via PacketHeader.Consume, when the packet is being parsed. @@ -69,7 +69,7 @@ type PacketBuffer struct { // // The bytes backing Data are immutable, a.k.a. users shouldn't write to its // backing storage. - Data buffer.VectorisedView + data buffer.VectorisedView // headers stores metadata about each header. headers [numHeaderType]headerInfo @@ -127,7 +127,7 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - Data: opts.Data, + data: opts.Data, } if opts.ReserveHeaderBytes != 0 { pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) @@ -184,13 +184,18 @@ func (pk *PacketBuffer) HeaderSize() int { // Size returns the size of packet in bytes. func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.Data.Size() + return pk.HeaderSize() + pk.data.Size() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize + return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize +} + +// Data returns the handle to data portion of pk. +func (pk *PacketBuffer) Data() PacketData { + return PacketData{pk: pk} } // Views returns the underlying storage of the whole packet. @@ -204,7 +209,7 @@ func (pk *PacketBuffer) Views() []buffer.View { } } - dataViews := pk.Data.Views() + dataViews := pk.data.Views() var vs []buffer.View if useHeader { @@ -242,11 +247,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum if h.buf != nil { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.Data.PullUp(size) + v, ok := pk.data.PullUp(size) if !ok { return } - pk.Data.TrimFront(size) + pk.data.TrimFront(size) h.buf = v return h.buf, true } @@ -258,7 +263,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), + data: pk.data.Clone(nil), headers: pk.headers, header: pk.header, Hash: pk.Hash, @@ -339,13 +344,234 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { return h.pk.consume(h.typ, size) } +// PacketData represents the data portion of a PacketBuffer. +type PacketData struct { + pk *PacketBuffer +} + +// PullUp returns a contiguous view of size bytes from the beginning of d. +// Callers should not write to or keep the view for later use. +func (d PacketData) PullUp(size int) (buffer.View, bool) { + return d.pk.data.PullUp(size) +} + +// TrimFront removes count from the beginning of d. It panics if count > +// d.Size(). +func (d PacketData) TrimFront(count int) { + d.pk.data.TrimFront(count) +} + +// CapLength reduces d to at most length bytes. +func (d PacketData) CapLength(length int) { + d.pk.data.CapLength(length) +} + +// Views returns the underlying storage of d in a slice of Views. Caller should +// not modify the returned slice. +func (d PacketData) Views() []buffer.View { + return d.pk.data.Views() +} + +// AppendView appends v into d, taking the ownership of v. +func (d PacketData) AppendView(v buffer.View) { + d.pk.data.AppendView(v) +} + +// ReadFromData moves at most count bytes from the beginning of srcData to the +// end of d and returns the number of bytes moved. +func (d PacketData) ReadFromData(srcData PacketData, count int) int { + return srcData.pk.data.ReadToVV(&d.pk.data, count) +} + +// ReadFromVV moves at most count bytes from the beginning of srcVV to the end +// of d and returns the number of bytes moved. +func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { + return srcVV.ReadToVV(&d.pk.data, count) +} + +// Size returns the number of bytes in the data payload of the packet. +func (d PacketData) Size() int { + return d.pk.data.Size() +} + +// AsRange returns a Range representing the current data payload of the packet. +func (d PacketData) AsRange() Range { + return Range{ + pk: d.pk, + offset: d.pk.HeaderSize(), + length: d.Size(), + } +} + +// ExtractVV returns a VectorisedView of d. This method has the semantic to +// destruct the underlying packet, hence the packet cannot be used again. +// +// This method exists for compatibility between PacketBuffer and VectorisedView. +// It may be removed later and should be used with care. +func (d PacketData) ExtractVV() buffer.VectorisedView { + return d.pk.data +} + +// Replace replaces the data portion of the packet with vv, taking the ownership +// of vv. +// +// This method exists for compatibility between PacketBuffer and VectorisedView. +// It may be removed later and should be used with care. +func (d PacketData) Replace(vv buffer.VectorisedView) { + d.pk.data = vv +} + +// Range represents a contiguous subportion of a PacketBuffer. +type Range struct { + pk *PacketBuffer + offset int + length int +} + +// Size returns the number of bytes in r. +func (r Range) Size() int { + return r.length +} + +// SubRange returns a new Range starting at off bytes of r. It returns an empty +// range if off is out-of-bounds. +func (r Range) SubRange(off int) Range { + if off > r.length { + return Range{pk: r.pk} + } + return Range{ + pk: r.pk, + offset: r.offset + off, + length: r.length - off, + } +} + +// Capped returns a new Range with the same starting point of r and length +// capped at max. +func (r Range) Capped(max int) Range { + if r.length <= max { + return r + } + return Range{ + pk: r.pk, + offset: r.offset, + length: max, + } +} + +// AsView returns the backing storage of r if possible. It will allocate a new +// View if r spans multiple pieces internally. Caller should not write to the +// returned View in any way. +func (r Range) AsView() buffer.View { + var allocated bool + var v buffer.View + r.iterate(func(b []byte) { + if v == nil { + // v has not been assigned, allowing first view to be returned. + v = b + } else { + // v has been assigned. This range spans more than a view, a new view + // needs to be allocated. + if !allocated { + allocated = true + all := make([]byte, 0, r.length) + all = append(all, v...) + v = all + } + v = append(v, b...) + } + }) + return v +} + +// ToOwnedView returns a owned copy of data in r. +func (r Range) ToOwnedView() buffer.View { + if r.length == 0 { + return nil + } + all := make([]byte, 0, r.length) + r.iterate(func(b []byte) { + all = append(all, b...) + }) + return all +} + +// Checksum calculates the RFC 1071 checksum for the underlying bytes of r. +func (r Range) Checksum() uint16 { + var c header.Checksumer + r.iterate(c.Add) + return c.Checksum() +} + +// iterate calls fn for each piece in r. fn is always called with a non-empty +// slice. +func (r Range) iterate(fn func([]byte)) { + w := window{ + offset: r.offset, + length: r.length, + } + // Header portion. + for i := range r.pk.headers { + if b := w.process(r.pk.headers[i].buf); len(b) > 0 { + fn(b) + } + if w.isDone() { + break + } + } + // Data portion. + if !w.isDone() { + for _, v := range r.pk.data.Views() { + if b := w.process(v); len(b) > 0 { + fn(b) + } + if w.isDone() { + break + } + } + } +} + +// window represents contiguous region of byte stream. User would call process() +// to input bytes, and obtain a subslice that is inside the window. +type window struct { + offset int + length int +} + +// isDone returns true if the window has passed and further process() calls will +// always return an empty slice. This can be used to end processing early. +func (w *window) isDone() bool { + return w.length == 0 +} + +// process feeds b in and returns a subslice that is inside the window. The +// returned slice will be a subslice of b, and it does not keep b after method +// returns. This method may return an empty slice if nothing in b is inside the +// window. +func (w *window) process(b []byte) (inWindow []byte) { + if w.offset >= len(b) { + w.offset -= len(b) + return nil + } + if w.offset > 0 { + b = b[w.offset:] + w.offset = 0 + } + if w.length < len(b) { + b = b[:w.length] + } + w.length -= len(b) + return b +} + // PayloadSince returns packet payload starting from and including a particular // header. // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.Data.Size() + size := h.pk.data.Size() for _, hinfo := range h.pk.headers[h.typ:] { size += len(hinfo.buf) } @@ -356,7 +582,7 @@ func PayloadSince(h PacketHeader) buffer.View { v = append(v, hinfo.buf...) } - for _, view := range h.pk.Data.Views() { + for _, view := range h.pk.data.Views() { v = append(v, view...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index c6fa8da5f..6728370c3 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -15,9 +15,11 @@ package stack import ( "bytes" + "fmt" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) func TestPacketHeaderPush(t *testing.T) { @@ -110,7 +112,7 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkData(t, pk, test.data) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), concatViews(test.link, test.network, test.transport, test.data)) // Check the after values for each header. @@ -204,7 +206,7 @@ func TestPacketHeaderConsume(t *testing.T) { transport = test.data[test.link+test.network:][:test.transport] payload = test.data[allHdrSize:] ) - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkData(t, pk, payload) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) // Check the after values for each header. checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) @@ -340,6 +342,158 @@ func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { } } +func TestPacketBufferData(t *testing.T) { + for _, tc := range []struct { + name string + makePkt func(*testing.T) *PacketBuffer + data string + }{ + { + name: "inbound packet", + makePkt: func(*testing.T) *PacketBuffer { + pkt := NewPacketBuffer(PacketBufferOptions{ + Data: vv("aabbbbccccccDATA"), + }) + pkt.LinkHeader().Consume(2) + pkt.NetworkHeader().Consume(4) + pkt.TransportHeader().Consume(6) + return pkt + }, + data: "DATA", + }, + { + name: "outbound packet", + makePkt: func(*testing.T) *PacketBuffer { + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: 12, + Data: vv("DATA"), + }) + copy(pkt.TransportHeader().Push(6), []byte("cccccc")) + copy(pkt.NetworkHeader().Push(4), []byte("bbbb")) + copy(pkt.LinkHeader().Push(2), []byte("aa")) + return pkt + }, + data: "DATA", + }, + } { + t.Run(tc.name, func(t *testing.T) { + // PullUp + for _, n := range []int{1, len(tc.data)} { + t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + v, ok := pkt.Data().PullUp(n) + wantV := []byte(tc.data)[:n] + if !ok || !bytes.Equal(v, wantV) { + t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV) + } + }) + } + t.Run("PullUpOutOfBounds", func(t *testing.T) { + n := len(tc.data) + 1 + pkt := tc.makePkt(t) + v, ok := pkt.Data().PullUp(n) + if ok || v != nil { + t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok) + } + }) + + // TrimFront + for _, n := range []int{1, len(tc.data)} { + t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + pkt.Data().TrimFront(n) + + checkData(t, pkt, []byte(tc.data)[n:]) + }) + } + + // CapLength + for _, n := range []int{0, 1, len(tc.data)} { + t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + pkt.Data().CapLength(n) + + want := []byte(tc.data) + if n < len(want) { + want = want[:n] + } + checkData(t, pkt, want) + }) + } + + // Views + t.Run("Views", func(t *testing.T) { + pkt := tc.makePkt(t) + checkData(t, pkt, []byte(tc.data)) + }) + + // AppendView + t.Run("AppendView", func(t *testing.T) { + s := "APPEND" + + pkt := tc.makePkt(t) + pkt.Data().AppendView(buffer.View(s)) + + checkData(t, pkt, []byte(tc.data+s)) + }) + + // ReadFromData/VV + for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { + t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { + s := "TO READ" + otherPkt := NewPacketBuffer(PacketBufferOptions{ + Data: vv(s, s), + }) + s += s + + pkt := tc.makePkt(t) + pkt.Data().ReadFromData(otherPkt.Data(), n) + + if n < len(s) { + s = s[:n] + } + checkData(t, pkt, []byte(tc.data+s)) + }) + t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { + s := "TO READ" + srcVV := vv(s, s) + s += s + + pkt := tc.makePkt(t) + pkt.Data().ReadFromVV(&srcVV, n) + + if n < len(s) { + s = s[:n] + } + checkData(t, pkt, []byte(tc.data+s)) + }) + } + + // ExtractVV + t.Run("ExtractVV", func(t *testing.T) { + pkt := tc.makePkt(t) + extractedVV := pkt.Data().ExtractVV() + + got := extractedVV.ToOwnedView() + want := []byte(tc.data) + if !bytes.Equal(got, want) { + t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) + } + }) + + // Replace + t.Run("Replace", func(t *testing.T) { + s := "REPLACED" + + pkt := tc.makePkt(t) + pkt.Data().Replace(vv(s)) + + checkData(t, pkt, []byte(s)) + }) + }) + } +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -356,7 +510,7 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkData(t, pk, data) checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) // Check the initial values for each header. checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) @@ -383,6 +537,70 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { } } +func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { + t.Helper() + if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { + t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + } + if got := pkt.Data().Size(); got != len(want) { + t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) + } + + t.Run("AsRange", func(t *testing.T) { + // Full range + checkRange(t, pkt.Data().AsRange(), want) + + // SubRange + for _, off := range []int{0, 1, len(want), len(want) + 1} { + t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) { + // Empty when off is greater than the size of range. + var sub []byte + if off < len(want) { + sub = want[off:] + } + checkRange(t, pkt.Data().AsRange().SubRange(off), sub) + }) + } + + // Capped + for _, n := range []int{0, 1, len(want), len(want) + 1} { + t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) { + sub := want + if n < len(sub) { + sub = sub[:n] + } + checkRange(t, pkt.Data().AsRange().Capped(n), sub) + }) + } + }) +} + +func checkRange(t *testing.T, r Range, data []byte) { + if got, want := r.Size(), len(data); got != want { + t.Errorf("r.Size() = %d, want %d", got, want) + } + if got := r.AsView(); !bytes.Equal(got, data) { + t.Errorf("r.AsView() = %x, want %x", got, data) + } + if got := r.ToOwnedView(); !bytes.Equal(got, data) { + t.Errorf("r.ToOwnedView() = %x, want %x", got, data) + } + if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want { + t.Errorf("r.Checksum() = %x, want %x", got, want) + } +} + +func vv(pieces ...string) buffer.VectorisedView { + var views []buffer.View + var size int + for _, p := range pieces { + v := buffer.View([]byte(p)) + size += len(v) + views = append(views, v) + } + return buffer.NewVectorisedView(size, views) +} + func makeView(size int) buffer.View { b := byte(size) return bytes.Repeat([]byte{b}, size) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 43e9e4beb..85f0f471a 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -852,18 +852,46 @@ type InjectableLinkEndpoint interface { InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// DADResult is the result of a duplicate address detection process. -type DADResult struct { - // Resolved is true when DAD completed without detecting a duplicate address - // on the link. - // - // Ignored when Err is non-nil. - Resolved bool +// DADResult is a marker interface for the result of a duplicate address +// detection process. +type DADResult interface { + isDADResult() +} + +var _ DADResult = (*DADSucceeded)(nil) + +// DADSucceeded indicates DAD completed without finding any duplicate addresses. +type DADSucceeded struct{} - // Err is an error encountered while performing DAD. +func (*DADSucceeded) isDADResult() {} + +var _ DADResult = (*DADError)(nil) + +// DADError indicates DAD hit an error. +type DADError struct { Err tcpip.Error } +func (*DADError) isDADResult() {} + +var _ DADResult = (*DADAborted)(nil) + +// DADAborted indicates DAD was aborted. +type DADAborted struct{} + +func (*DADAborted) isDADResult() {} + +var _ DADResult = (*DADDupAddrDetected)(nil) + +// DADDupAddrDetected indicates DAD detected a duplicate address. +type DADDupAddrDetected struct { + // HolderLinkAddress is the link address of the node that holds the duplicate + // address. + HolderLinkAddress tcpip.LinkAddress +} + +func (*DADDupAddrDetected) isDADResult() {} + // DADCompletionHandler is a handler for DAD completion. type DADCompletionHandler func(DADResult) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index de94ddfda..53370c354 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -813,6 +813,18 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { return forwardingProtocol.Forwarding() } +// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in +// both IPv4 and IPv6. +func (s *Stack) PortRange() (uint16, uint16) { + return s.PortManager.PortRange() +} + +// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range +// (inclusive). +func (s *Stack) SetPortRange(start uint16, end uint16) tcpip.Error { + return s.PortManager.SetPortRange(start, end) +} + // SetRouteTable assigns the route table to be used by this stack. It // specifies which NIC to use for given destination address ranges. // diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8e39e828c..880219007 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -137,11 +137,11 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) + pkt.Data().TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -2605,7 +2605,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { + if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } @@ -3289,7 +3289,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { + if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } @@ -4294,7 +4294,7 @@ func TestWritePacketToRemote(t *testing.T) { if pkt.Route.RemoteLinkAddress != linkAddr2 { t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) } - if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" { t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) } }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e799f9290..e188efccb 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -359,7 +359,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[0] } - if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent { + if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent { return mpep.endpoints[len(mpep.endpoints)-1] } @@ -410,7 +410,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 { return &tcpip.ErrPortInUse{} } } @@ -429,7 +429,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 { return &tcpip.ErrPortInUse{} } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 01a4389e3..87ea09a5e 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1258,44 +1258,38 @@ func (m *MultiCounterStat) IncrementBy(v uint64) { type ICMPv4PacketStats struct { // LINT.IfChange(ICMPv4PacketStats) - // Echo is the total number of ICMPv4 echo packets counted. - Echo *StatCounter + // EchoRequest is the number of ICMPv4 echo packets counted. + EchoRequest *StatCounter - // EchoReply is the total number of ICMPv4 echo reply packets counted. + // EchoReply is the number of ICMPv4 echo reply packets counted. EchoReply *StatCounter - // DstUnreachable is the total number of ICMPv4 destination unreachable - // packets counted. + // DstUnreachable is the number of ICMPv4 destination unreachable packets + // counted. DstUnreachable *StatCounter - // SrcQuench is the total number of ICMPv4 source quench packets - // counted. + // SrcQuench is the number of ICMPv4 source quench packets counted. SrcQuench *StatCounter - // Redirect is the total number of ICMPv4 redirect packets counted. + // Redirect is the number of ICMPv4 redirect packets counted. Redirect *StatCounter - // TimeExceeded is the total number of ICMPv4 time exceeded packets - // counted. + // TimeExceeded is the number of ICMPv4 time exceeded packets counted. TimeExceeded *StatCounter - // ParamProblem is the total number of ICMPv4 parameter problem packets - // counted. + // ParamProblem is the number of ICMPv4 parameter problem packets counted. ParamProblem *StatCounter - // Timestamp is the total number of ICMPv4 timestamp packets counted. + // Timestamp is the number of ICMPv4 timestamp packets counted. Timestamp *StatCounter - // TimestampReply is the total number of ICMPv4 timestamp reply packets - // counted. + // TimestampReply is the number of ICMPv4 timestamp reply packets counted. TimestampReply *StatCounter - // InfoRequest is the total number of ICMPv4 information request - // packets counted. + // InfoRequest is the number of ICMPv4 information request packets counted. InfoRequest *StatCounter - // InfoReply is the total number of ICMPv4 information reply packets - // counted. + // InfoReply is the number of ICMPv4 information reply packets counted. InfoReply *StatCounter // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4PacketStats) @@ -1307,12 +1301,11 @@ type ICMPv4SentPacketStats struct { ICMPv4PacketStats - // Dropped is the total number of ICMPv4 packets dropped due to link - // layer errors. + // Dropped is the number of ICMPv4 packets dropped due to link layer errors. Dropped *StatCounter - // RateLimited is the total number of ICMPv4 packets dropped due to - // rate limit being exceeded. + // RateLimited is the number of ICMPv4 packets dropped due to rate limit being + // exceeded. RateLimited *StatCounter // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4SentPacketStats) @@ -1324,7 +1317,7 @@ type ICMPv4ReceivedPacketStats struct { ICMPv4PacketStats - // Invalid is the total number of invalid ICMPv4 packets received. + // Invalid is the number of invalid ICMPv4 packets received. Invalid *StatCounter // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4ReceivedPacketStats) @@ -1347,59 +1340,50 @@ type ICMPv4Stats struct { type ICMPv6PacketStats struct { // LINT.IfChange(ICMPv6PacketStats) - // EchoRequest is the total number of ICMPv6 echo request packets - // counted. + // EchoRequest is the number of ICMPv6 echo request packets counted. EchoRequest *StatCounter - // EchoReply is the total number of ICMPv6 echo reply packets counted. + // EchoReply is the number of ICMPv6 echo reply packets counted. EchoReply *StatCounter - // DstUnreachable is the total number of ICMPv6 destination unreachable - // packets counted. + // DstUnreachable is the number of ICMPv6 destination unreachable packets + // counted. DstUnreachable *StatCounter - // PacketTooBig is the total number of ICMPv6 packet too big packets - // counted. + // PacketTooBig is the number of ICMPv6 packet too big packets counted. PacketTooBig *StatCounter - // TimeExceeded is the total number of ICMPv6 time exceeded packets - // counted. + // TimeExceeded is the number of ICMPv6 time exceeded packets counted. TimeExceeded *StatCounter - // ParamProblem is the total number of ICMPv6 parameter problem packets - // counted. + // ParamProblem is the number of ICMPv6 parameter problem packets counted. ParamProblem *StatCounter - // RouterSolicit is the total number of ICMPv6 router solicit packets - // counted. + // RouterSolicit is the number of ICMPv6 router solicit packets counted. RouterSolicit *StatCounter - // RouterAdvert is the total number of ICMPv6 router advert packets - // counted. + // RouterAdvert is the number of ICMPv6 router advert packets counted. RouterAdvert *StatCounter - // NeighborSolicit is the total number of ICMPv6 neighbor solicit - // packets counted. + // NeighborSolicit is the number of ICMPv6 neighbor solicit packets counted. NeighborSolicit *StatCounter - // NeighborAdvert is the total number of ICMPv6 neighbor advert packets - // counted. + // NeighborAdvert is the number of ICMPv6 neighbor advert packets counted. NeighborAdvert *StatCounter - // RedirectMsg is the total number of ICMPv6 redirect message packets - // counted. + // RedirectMsg is the number of ICMPv6 redirect message packets counted. RedirectMsg *StatCounter - // MulticastListenerQuery is the total number of Multicast Listener Query - // messages counted. + // MulticastListenerQuery is the number of Multicast Listener Query messages + // counted. MulticastListenerQuery *StatCounter - // MulticastListenerReport is the total number of Multicast Listener Report - // messages counted. + // MulticastListenerReport is the number of Multicast Listener Report messages + // counted. MulticastListenerReport *StatCounter - // MulticastListenerDone is the total number of Multicast Listener Done - // messages counted. + // MulticastListenerDone is the number of Multicast Listener Done messages + // counted. MulticastListenerDone *StatCounter // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6PacketStats) @@ -1411,12 +1395,11 @@ type ICMPv6SentPacketStats struct { ICMPv6PacketStats - // Dropped is the total number of ICMPv6 packets dropped due to link - // layer errors. + // Dropped is the number of ICMPv6 packets dropped due to link layer errors. Dropped *StatCounter - // RateLimited is the total number of ICMPv6 packets dropped due to - // rate limit being exceeded. + // RateLimited is the number of ICMPv6 packets dropped due to rate limit being + // exceeded. RateLimited *StatCounter // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6SentPacketStats) @@ -1428,15 +1411,15 @@ type ICMPv6ReceivedPacketStats struct { ICMPv6PacketStats - // Unrecognized is the total number of ICMPv6 packets received that the - // transport layer does not know how to parse. + // Unrecognized is the number of ICMPv6 packets received that the transport + // layer does not know how to parse. Unrecognized *StatCounter - // Invalid is the total number of invalid ICMPv6 packets received. + // Invalid is the number of invalid ICMPv6 packets received. Invalid *StatCounter - // RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets - // dropped due to being router-specific packets. + // RouterOnlyPacketsDroppedByHost is the number of ICMPv6 packets dropped due + // to being router-specific packets. RouterOnlyPacketsDroppedByHost *StatCounter // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6ReceivedPacketStats) @@ -1468,18 +1451,18 @@ type ICMPStats struct { type IGMPPacketStats struct { // LINT.IfChange(IGMPPacketStats) - // MembershipQuery is the total number of Membership Query messages counted. + // MembershipQuery is the number of Membership Query messages counted. MembershipQuery *StatCounter - // V1MembershipReport is the total number of Version 1 Membership Report - // messages counted. + // V1MembershipReport is the number of Version 1 Membership Report messages + // counted. V1MembershipReport *StatCounter - // V2MembershipReport is the total number of Version 2 Membership Report - // messages counted. + // V2MembershipReport is the number of Version 2 Membership Report messages + // counted. V2MembershipReport *StatCounter - // LeaveGroup is the total number of Leave Group messages counted. + // LeaveGroup is the number of Leave Group messages counted. LeaveGroup *StatCounter // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPPacketStats) @@ -1491,7 +1474,7 @@ type IGMPSentPacketStats struct { IGMPPacketStats - // Dropped is the total number of IGMP packets dropped. + // Dropped is the number of IGMP packets dropped. Dropped *StatCounter // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPSentPacketStats) @@ -1503,15 +1486,14 @@ type IGMPReceivedPacketStats struct { IGMPPacketStats - // Invalid is the total number of invalid IGMP packets received. + // Invalid is the number of invalid IGMP packets received. Invalid *StatCounter - // ChecksumErrors is the total number of IGMP packets dropped due to bad - // checksums. + // ChecksumErrors is the number of IGMP packets dropped due to bad checksums. ChecksumErrors *StatCounter - // Unrecognized is the total number of unrecognized messages counted, these - // are silently ignored for forward-compatibilty. + // Unrecognized is the number of unrecognized messages counted, these are + // silently ignored for forward-compatibilty. Unrecognized *StatCounter // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPReceivedPacketStats) @@ -1534,51 +1516,50 @@ type IGMPStats struct { type IPStats struct { // LINT.IfChange(IPStats) - // PacketsReceived is the total number of IP packets received from the - // link layer. + // PacketsReceived is the number of IP packets received from the link layer. PacketsReceived *StatCounter - // DisabledPacketsReceived is the total number of IP packets received from the - // link layer when the IP layer is disabled. + // DisabledPacketsReceived is the number of IP packets received from the link + // layer when the IP layer is disabled. DisabledPacketsReceived *StatCounter - // InvalidDestinationAddressesReceived is the total number of IP packets - // received with an unknown or invalid destination address. + // InvalidDestinationAddressesReceived is the number of IP packets received + // with an unknown or invalid destination address. InvalidDestinationAddressesReceived *StatCounter - // InvalidSourceAddressesReceived is the total number of IP packets received - // with a source address that should never have been received on the wire. + // InvalidSourceAddressesReceived is the number of IP packets received with a + // source address that should never have been received on the wire. InvalidSourceAddressesReceived *StatCounter - // PacketsDelivered is the total number of incoming IP packets that - // are successfully delivered to the transport layer. + // PacketsDelivered is the number of incoming IP packets that are successfully + // delivered to the transport layer. PacketsDelivered *StatCounter - // PacketsSent is the total number of IP packets sent via WritePacket. + // PacketsSent is the number of IP packets sent via WritePacket. PacketsSent *StatCounter - // OutgoingPacketErrors is the total number of IP packets which failed - // to write to a link-layer endpoint. + // OutgoingPacketErrors is the number of IP packets which failed to write to a + // link-layer endpoint. OutgoingPacketErrors *StatCounter - // MalformedPacketsReceived is the total number of IP Packets that were - // dropped due to the IP packet header failing validation checks. + // MalformedPacketsReceived is the number of IP Packets that were dropped due + // to the IP packet header failing validation checks. MalformedPacketsReceived *StatCounter - // MalformedFragmentsReceived is the total number of IP Fragments that were - // dropped due to the fragment failing validation checks. + // MalformedFragmentsReceived is the number of IP Fragments that were dropped + // due to the fragment failing validation checks. MalformedFragmentsReceived *StatCounter - // IPTablesPreroutingDropped is the total number of IP packets dropped - // in the Prerouting chain. + // IPTablesPreroutingDropped is the number of IP packets dropped in the + // Prerouting chain. IPTablesPreroutingDropped *StatCounter - // IPTablesInputDropped is the total number of IP packets dropped in - // the Input chain. + // IPTablesInputDropped is the number of IP packets dropped in the Input + // chain. IPTablesInputDropped *StatCounter - // IPTablesOutputDropped is the total number of IP packets dropped in - // the Output chain. + // IPTablesOutputDropped is the number of IP packets dropped in the Output + // chain. IPTablesOutputDropped *StatCounter // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 0cb9d034e..38c2f321b 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -135,14 +135,15 @@ func TestForwarding(t *testing.T) { name string proto tcpip.TransportProtocolNumber expectedConnectErr tcpip.Error - setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + setupServer func(t *testing.T, ep tcpip.Endpoint) + setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) needRemoteAddr bool }{ { name: "UDP", proto: udp.ProtocolNumber, expectedConnectErr: nil, - setupServerSide: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() if err := ep.Connect(clientAddr); err != nil { @@ -156,12 +157,16 @@ func TestForwarding(t *testing.T) { name: "TCP", proto: tcp.ProtocolNumber, expectedConnectErr: &tcpip.ErrConnectStarted{}, - setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + setupServer: func(t *testing.T, ep tcpip.Endpoint) { t.Helper() if err := ep.Listen(1); err != nil { t.Fatalf("ep.Listen(1): %s", err) } + }, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var addr tcpip.FullAddress for { newEP, wq, err := ep.Accept(&addr) @@ -214,6 +219,9 @@ func TestForwarding(t *testing.T) { t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) } + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } { err := epsAndAddrs.clientEP.Connect(serverAddr) if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { @@ -229,7 +237,7 @@ func TestForwarding(t *testing.T) { serverEP := epsAndAddrs.serverEP serverCH := epsAndAddrs.serverReadableCH - if ep, ch := subTest.setupServerSide(t, serverEP, serverCH, clientAddr); ep != nil { + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil { defer ep.Close() serverEP = ep serverCH = ch @@ -256,13 +264,20 @@ func TestForwarding(t *testing.T) { read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { t.Helper() - // Wait for the endpoint to be readable. - <-ch var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break } readResult := tcpip.ReadResult{ diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 165f73f21..095623789 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -675,9 +675,7 @@ func TestWritePacketsLinkResolution(t *testing.T) { Length: length, }) xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) pkts.PushBack(pkt) @@ -1169,53 +1167,53 @@ func TestDAD(t *testing.T) { } tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - dadNetProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedResolved bool + name string + netProto tcpip.NetworkProtocolNumber + dadNetProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedResult stack.DADResult }{ { - name: "IPv4 own address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - expectedResolved: true, + name: "IPv4 own address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, + expectedResult: &stack.DADSucceeded{}, }, { - name: "IPv6 own address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - expectedResolved: true, + name: "IPv6 own address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, + expectedResult: &stack.DADSucceeded{}, }, { - name: "IPv4 duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedResolved: false, + name: "IPv4 duplicate address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2}, }, { - name: "IPv6 duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedResolved: false, + name: "IPv6 duplicate address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2}, }, { - name: "IPv4 no duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedResolved: true, + name: "IPv4 no duplicate address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, + expectedResult: &stack.DADSucceeded{}, }, { - name: "IPv6 no duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedResolved: true, + name: "IPv6 no duplicate address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, + expectedResult: &stack.DADSucceeded{}, }, } @@ -1262,7 +1260,7 @@ func TestDAD(t *testing.T) { } expectResults := 1 - if test.expectedResolved { + if _, ok := test.expectedResult.(*stack.DADSucceeded); ok { const delta = time.Nanosecond clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta) select { @@ -1287,7 +1285,7 @@ func TestDAD(t *testing.T) { } for i := 0; i < expectResults; i++ { - if diff := cmp.Diff(stack.DADResult{Resolved: test.expectedResolved}, <-ch); diff != "" { + if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" { t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) } } diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index c56155ea2..80afc2825 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -38,7 +38,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) type ndpDispatcher struct{} -func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { +func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index e4439ba79..29266a4fc 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -75,7 +75,11 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetType(header.ICMPv6EchoRequest) pkt.SetCode(0) pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, utils.RemoteIPv6Addr, dst, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: utils.RemoteIPv6Addr, + Dst: dst, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index f5e1a6e45..06c63e74a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -26,6 +26,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// TODO(https://gvisor.dev/issues/5623): Unit test this package. + // +stateify savable type icmpPacket struct { icmpPacketEntry @@ -414,15 +416,27 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return &tcpip.ErrInvalidEndpointState{} } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest + icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - pkt.Data = data.ToVectorisedView() + pkt.Data().AppendView(data) if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) + + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() + return err + } + + sentStat.Increment() + return nil } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { @@ -444,15 +458,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return &tcpip.ErrInvalidEndpointState{} } - - dataVV := data.ToVectorisedView() - icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) - pkt.Data = dataVV + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest + + pkt.Data().AppendView(data) + dataRange := pkt.Data().AsRange() + icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpv6, + Src: r.LocalAddress, + Dst: r.RemoteAddress, + PayloadCsum: dataRange.Checksum(), + PayloadLen: dataRange.Size(), + })) if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) + + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() + } + + sentStat.Increment() + return nil } // checkV4MappedLocked determines the effective network protocol and converts @@ -763,7 +793,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB // ICMP socket's data includes ICMP header. packet.data = pkt.TransportHeader().View().ToVectorisedView() - packet.data.Append(pkt.Data) + packet.data.Append(pkt.Data().ExtractVV()) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 73bb66830..367757d3b 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -432,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Cooked packets can simply be queued. switch pkt.PktType { case tcpip.PacketHost: - packet.data = pkt.Data + packet.data = pkt.Data().ExtractVV() case tcpip.PacketOutgoing: // Strip Link Header. var combinedVV buffer.VectorisedView @@ -442,7 +442,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, if v := pkt.TransportHeader().View(); !v.IsEmpty() { combinedVV.AppendView(v) } - combinedVV.Append(pkt.Data) + combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV default: panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) @@ -468,7 +468,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) } combinedVV := linkHeader.ToVectorisedView() - combinedVV.Append(pkt.Data) + combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV } else { packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index fe8e9c751..2709be90c 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -644,7 +644,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } else { combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() } - combinedVV.Append(pkt.Data) + combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV packet.timestampNS = e.stack.Clock().NowNanoseconds() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index fcdd032c5..a69d6624d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -105,7 +105,6 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp/testing/context", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 842c1622b..3b574837c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -432,15 +433,16 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // * e.mu is held. func (e *endpoint) reserveTupleLocked() bool { dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} - if !e.stack.ReserveTuple( - e.effectiveNetProtos, - ProtocolNumber, - e.ID.LocalAddress, - e.ID.LocalPort, - e.boundPortFlags, - e.boundBindToDevice, - dest, - ) { + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: dest, + } + if !e.stack.ReserveTuple(portRes) { return false } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 461b1a9d7..3404af6bb 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -68,7 +68,7 @@ type handshake struct { ep *endpoint state handshakeState active bool - flags uint8 + flags header.TCPFlags ackNum seqnum.Value // iss is the initial send sequence number, as defined in RFC 793. @@ -606,7 +606,7 @@ func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer func (bt *backoffTimer) reset() tcpip.Error { bt.timeout *= 2 - if bt.timeout > MaxRTO { + if bt.timeout > bt.maxTimeout { return &tcpip.ErrTimeout{} } bt.t.Reset(bt.timeout) @@ -700,7 +700,7 @@ type tcpFields struct { id stack.TransportEndpointID ttl uint8 tos uint8 - flags byte + flags header.TCPFlags seq seqnum.Value ack seqnum.Value rcvWnd seqnum.Size @@ -752,7 +752,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } } @@ -786,7 +786,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso }) pkt.Hash = tf.txHash pkt.Owner = owner - data.ReadToVV(&pkt.Data, packetSize) + pkt.Data().ReadFromVV(&data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) pkts.PushBack(pkt) @@ -877,7 +877,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendRaw sends a TCP segment to the endpoint's peer. -func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f47b39ccc..129f36d11 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -760,6 +760,7 @@ func (e *endpoint) LockUser() { // protocol goroutine altogether. // // Precondition: e.LockUser() must have been called before calling e.UnlockUser() +// +checklocks:e.mu func (e *endpoint) UnlockUser() { // Lock segment queue before checking so that we avoid a race where // segments can be queued between the time we check if queue is empty @@ -800,6 +801,7 @@ func (e *endpoint) StopWork() { } // ResumeWork resumes packet processing. Only to be used in tests. +// +checklocks:e.mu func (e *endpoint) ResumeWork() { e.mu.Unlock() } @@ -1095,7 +1097,16 @@ func (e *endpoint) closeNoShutdownLocked() { e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: e.boundDest, + } + e.stack.ReleasePort(portRes) e.isPortReserved = false e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -1170,7 +1181,16 @@ func (e *endpoint) cleanupLocked() { } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: e.boundDest, + } + e.stack.ReleasePort(portRes) e.isPortReserved = false } e.boundBindToDevice = 0 @@ -2218,7 +2238,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) h.Write(portBuf) - portOffset := h.Sum32() + portOffset := uint16(h.Sum32()) var twReuse tcpip.TCPTimeWaitReuseOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { @@ -2240,7 +2260,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: p, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: addr, + } + if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } @@ -2278,7 +2307,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: p, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: addr, + } + if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { return false, nil } } @@ -2286,7 +2324,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp id := e.ID id.LocalPort = p if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: p, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: addr, + } + e.stack.ReleasePort(portRes) if _, ok := err.(*tcpip.ErrPortInUse); ok { return false, nil } @@ -2602,7 +2649,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: addr.Addr, + Port: addr.Port, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: tcpip.FullAddress{}, + } + port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { id := e.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2614,9 +2670,9 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { // address/port. Hence this will only return an error if there is a matching // listening endpoint. if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { - return false + return false, nil } - return true + return true, nil }) if err != nil { return err @@ -2699,7 +2755,7 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p Cause: transErr, // Linux passes the payload with the TCP header. We don't know if the TCP // header even exists, it may not for fragmented packets. - Payload: pkt.Data.ToView(), + Payload: pkt.Data().AsRange().ToOwnedView(), Dst: tcpip.FullAddress{ NIC: pkt.NICID, Addr: e.ID.RemoteAddress, diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index e4368026f..a53d76917 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -22,9 +22,11 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// +checklocks:e.mu func (e *endpoint) drainSegmentLocked() { // Drain only up to once. if e.drainDone != nil { @@ -207,7 +209,16 @@ func (e *endpoint) Resume(s *stack.Stack) { if err != nil { panic("unable to parse BindAddr: " + err.String()) } - if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok { + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: addr.Addr, + Port: addr.Port, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: e.boundDest, + } + if ok := e.stack.ReserveTuple(portRes); !ok { panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) } e.isPortReserved = true diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 04012cd40..2a4667906 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -226,7 +226,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) - flags := byte(header.TCPFlagRst) + flags := header.TCPFlagRst // As per RFC 793 page 35 (Reset Generation) // 1. If the connection does not exist (CLOSED) then a reset is sent // in response to any incoming segment except another reset. In diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index f27eef6a9..8edd6775b 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -62,7 +62,7 @@ type segment struct { views [8]buffer.View `state:"nosave"` sequenceNumber seqnum.Value ackNumber seqnum.Value - flags uint8 + flags header.TCPFlags window seqnum.Size // csum is only populated for received segments. csum uint16 @@ -98,7 +98,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * netProto: pkt.NetworkProtocolNumber, nicID: pkt.NICID, } - s.data = pkt.Data.Clone(s.views[:]) + s.data = pkt.Data().ExtractVV().Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() s.dataMemSize = s.data.Size() @@ -141,12 +141,12 @@ func (s *segment) clone() *segment { } // flagIsSet checks if at least one flag in flags is set in s.flags. -func (s *segment) flagIsSet(flags uint8) bool { +func (s *segment) flagIsSet(flags header.TCPFlags) bool { return s.flags&flags != 0 } // flagsAreSet checks if all flags in flags are set in s.flags. -func (s *segment) flagsAreSet(flags uint8) bool { +func (s *segment) flagsAreSet(flags header.TCPFlags) bool { return s.flags&flags == flags } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 83c8deb0e..18817029d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1613,7 +1613,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. -func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error { +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0128c1f7e..fd499a47b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -33,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -1373,7 +1372,7 @@ func TestTOSV4(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), ) @@ -1421,7 +1420,7 @@ func TestTrafficClassV6(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), ) @@ -2202,7 +2201,7 @@ func TestSimpleSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2242,7 +2241,7 @@ func TestZeroWindowSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2264,7 +2263,7 @@ func TestZeroWindowSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2311,7 +2310,7 @@ func TestScaledWindowConnect(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2342,7 +2341,7 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2415,7 +2414,7 @@ func TestScaledWindowAccept(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2488,7 +2487,7 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2666,7 +2665,7 @@ func TestSegmentMerging(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2689,7 +2688,7 @@ func TestSegmentMerging(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+11), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2738,7 +2737,7 @@ func TestDelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2786,7 +2785,7 @@ func TestUndelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2809,7 +2808,7 @@ func TestUndelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2872,7 +2871,7 @@ func TestMSSNotDelayed(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2923,7 +2922,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3438,7 +3437,7 @@ func TestMaxRTO(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) const numRetransmits = 2 @@ -3447,7 +3446,7 @@ func TestMaxRTO(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { @@ -3490,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.FragmentFlags(0), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} @@ -3502,7 +3501,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.FragmentFlags(0), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) id := header.IPv4(pkt).ID() @@ -3633,7 +3632,7 @@ func TestFinWithNoPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3710,7 +3709,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3729,7 +3728,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3796,7 +3795,7 @@ func TestFinWithPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3822,7 +3821,7 @@ func TestFinWithPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3886,7 +3885,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3907,7 +3906,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3923,7 +3922,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -4033,7 +4032,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -4783,7 +4782,8 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("unknown address type: '%s'", candidateAddressType) } - for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { + start, end := s.PortRange() + for i := start; i <= end; i++ { if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { t.Fatalf("Bind(%d) failed: %s", i, err) } @@ -4844,7 +4844,7 @@ func TestPathMTUDiscovery(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(seqNum), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) seqNum += uint32(size) @@ -5129,7 +5129,7 @@ func TestKeepalive(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7174,7 +7174,7 @@ func TestTCPCloseWithData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7274,7 +7274,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 5a9745ad7..cb4f82903 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -170,7 +170,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), checker.TCPTimestampChecker(true, 0, tsVal+1), ), ) @@ -231,7 +231,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), checker.TCPTimestampChecker(false, 0, 0), ), ) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b1cb9a324..2f1c1011d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -101,7 +101,7 @@ type Headers struct { AckNum seqnum.Value // Flags are the TCP flags in the TCP header. - Flags int + Flags header.TCPFlags // RcvWnd is the window to be advertised in the ReceiveWindow field of // the TCP header. @@ -452,7 +452,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp SeqNum: uint32(h.SeqNum), AckNum: uint32(h.AckNum), DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: uint8(h.Flags), + Flags: h.Flags, WindowSize: uint16(h.RcvWnd), }) @@ -544,7 +544,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -571,7 +571,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -650,7 +650,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp SeqNum: uint32(h.SeqNum), AckNum: uint32(h.AckNum), DataOffset: header.TCPMinimumSize, - Flags: uint8(h.Flags), + Flags: h.Flags, WindowSize: uint16(h.RcvWnd), }) @@ -780,7 +780,7 @@ type RawEndpoint struct { C *Context SrcPort uint16 DstPort uint16 - Flags int + Flags header.TCPFlags NextSeqNum seqnum.Value AckNum seqnum.Value WndSize seqnum.Size diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go index 5e271b7ca..6c5ddc3c7 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go @@ -465,7 +465,7 @@ func TestIgnoreBadResetOnSynSent(t *testing.T) { // Receive a RST with a bad ACK, it should not cause the connection to // be reset. acks := []uint32{1234, 1236, 1000, 5000} - flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} + flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} for _, a := range acks { for _, f := range flags { tcp.Encode(&header.TCPFields{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 807df2bb5..c0f566459 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -245,7 +245,16 @@ func (e *endpoint) Close() { switch e.EndpointState() { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: tcpip.FullAddress{}, + } + e.stack.ReleasePort(portRes) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} } @@ -920,7 +929,16 @@ func (e *endpoint) Disconnect() tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: tcpip.FullAddress{}, + } + e.stack.ReleasePort(portRes) e.boundPortFlags = ports.Flags{} } e.setEndpointState(StateInitial) @@ -1072,7 +1090,16 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: id.LocalAddress, + Port: id.LocalPort, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: tcpip.FullAddress{}, + } + port, err := e.stack.ReservePort(portRes, nil /* testPort */) if err != nil { return id, bindToDevice, err } @@ -1082,7 +1109,16 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: id.LocalAddress, + Port: id.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: bindToDevice, + Dest: tcpip.FullAddress{}, + } + e.stack.ReleasePort(portRes) e.boundPortFlags = ports.Flags{} } return id, bindToDevice, err @@ -1227,7 +1263,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { netHdr := pkt.Network() xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Data().Views() { xsum = header.Checksum(v, xsum) } return hdr.CalculateChecksum(xsum) == 0xffff @@ -1240,7 +1276,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) - if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { + if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1287,10 +1323,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB Addr: id.LocalAddress, Port: header.UDP(hdr).DestinationPort(), }, + data: pkt.Data().ExtractVV(), } - packet.data = pkt.Data e.rcvList.PushBack(packet) - e.rcvBufSize += pkt.Data.Size() + e.rcvBufSize += packet.data.Size() // Save any useful information from the network header to the packet. switch pkt.NetworkProtocolNumber { @@ -1327,7 +1363,7 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p if e.SocketOptions().GetRecvError() { // Linux passes the payload without the UDP header. var payload []byte - udp := header.UDP(pkt.Data.ToView()) + udp := header.UDP(pkt.Data().AsRange().ToOwnedView()) if len(udp) >= header.UDPMinimumSize { payload = udp.Payload() } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 427fdd0c9..1171aeb79 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -80,7 +80,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { // protocol but don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) - if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { + if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize { p.stack.Stats().UDP.MalformedPacketsReceived.Increment() return stack.UnknownDestinationPacketMalformed } |