diff options
Diffstat (limited to 'pkg')
117 files changed, 2667 insertions, 1868 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a461bb65e..29ead20d0 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -15,6 +15,7 @@ go_library( "bpf.go", "capability.go", "clone.go", + "context.go", "dev.go", "elf.go", "epoll.go", @@ -77,6 +78,7 @@ go_library( deps = [ "//pkg/abi", "//pkg/bits", + "//pkg/context", "//pkg/marshal", "//pkg/marshal/primitive", ], diff --git a/pkg/abi/linux/context.go b/pkg/abi/linux/context.go new file mode 100644 index 000000000..d2dbba183 --- /dev/null +++ b/pkg/abi/linux/context.go @@ -0,0 +1,36 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/context" +) + +// contextID is the linux package's type for context.Context.Value keys. +type contextID int + +const ( + // CtxSignalNoInfoFunc is a Context.Value key for a function to send signals. + CtxSignalNoInfoFunc contextID = iota +) + +// SignalNoInfoFuncFromContext returns a callback function that can be used to send a +// signal to the given context. +func SignalNoInfoFuncFromContext(ctx context.Context) func(Signal) error { + if f := ctx.Value(CtxSignalNoInfoFunc); f != nil { + return f.(func(Signal) error) + } + return nil +} diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go index c59c9c136..ea4fdca0f 100644 --- a/pkg/abi/linux/ioctl_tun.go +++ b/pkg/abi/linux/ioctl_tun.go @@ -26,4 +26,7 @@ const ( IFF_TAP = 0x0002 IFF_NO_PI = 0x1000 IFF_NOFILTER = 0x1000 + + // According to linux/if_tun.h "This flag has no real effect" + IFF_ONE_QUEUE = 0x2000 ) diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go index 74a55a123..514592b08 100644 --- a/pkg/crypto/crypto_stdlib.go +++ b/pkg/crypto/crypto_stdlib.go @@ -22,11 +22,11 @@ import ( // EcdsaVerify verifies the signature in r, s of hash using ECDSA and the // public key, pub. Its return value records whether the signature is valid. -func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { - return ecdsa.Verify(pub, hash, r, s) +func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) { + return ecdsa.Verify(pub, hash, r, s), nil } // SumSha384 returns the SHA384 checksum of the data. -func SumSha384(data []byte) (sum384 [sha512.Size384]byte) { - return sha512.Sum384(data) +func SumSha384(data []byte) ([sha512.Size384]byte, error) { + return sha512.Sum384(data), nil } diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go index b90f7c1be..4c99944e7 100644 --- a/pkg/sentry/fs/attr.go +++ b/pkg/sentry/fs/attr.go @@ -478,6 +478,20 @@ func (f FilePermissions) AnyRead() bool { return f.User.Read || f.Group.Read || f.Other.Read } +// HasSetUIDOrGID returns true if either the setuid or setgid bit is set. +func (f FilePermissions) HasSetUIDOrGID() bool { + return f.SetUID || f.SetGID +} + +// DropSetUIDAndMaybeGID turns off setuid, and turns off setgid if f allows +// group execution. +func (f *FilePermissions) DropSetUIDAndMaybeGID() { + f.SetUID = false + if f.Group.Execute { + f.SetGID = false + } +} + // FileOwner represents ownership of a file. // // +stateify savable diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index e84ba7a5d..c62effd52 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -16,6 +16,7 @@ package dev import ( + "fmt" "math" "gvisor.dev/gvisor/pkg/context" @@ -90,6 +91,11 @@ func newSymlink(ctx context.Context, target string, msrc *fs.MountSource) *fs.In // New returns the root node of a device filesystem. func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { + shm, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */) + if err != nil { + panic(fmt.Sprintf("tmpfs.NewDir failed: %v", err)) + } + contents := map[string]*fs.Inode{ "fd": newSymlink(ctx, "/proc/self/fd", msrc), "stdin": newSymlink(ctx, "/proc/self/fd/0", msrc), @@ -108,7 +114,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "random": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, randomDevMinor), "urandom": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, urandomDevMinor), - "shm": tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc), + "shm": shm, // A devpts is typically mounted at /dev/pts to provide // pseudoterminal support. Place an empty directory there for diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index e1e38b498..8ac3738e9 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -155,12 +155,20 @@ func (h *HostMappable) DecRef(fr memmap.FileRange) { // T2: Appends to file causing it to grow // T2: Writes to mapped pages and COW happens // T1: Continues and wronly invalidates the page mapped in step above. -func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error { +func (h *HostMappable) Truncate(ctx context.Context, newSize int64, uattr fs.UnstableAttr) error { h.truncateMu.Lock() defer h.truncateMu.Unlock() mask := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: newSize} + + // Truncating a file clears privilege bits. + if uattr.Perms.HasSetUIDOrGID() { + mask.Perms = true + attr.Perms = uattr.Perms + attr.Perms.DropSetUIDAndMaybeGID() + } + if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil { return err } @@ -193,10 +201,17 @@ func (h *HostMappable) Allocate(ctx context.Context, offset int64, length int64) } // Write writes to the file backing this mappable. -func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { +func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64, uattr fs.UnstableAttr) (int64, error) { h.truncateMu.RLock() + defer h.truncateMu.RUnlock() n, err := src.CopyInTo(ctx, &writer{ctx: ctx, hostMappable: h, off: offset}) - h.truncateMu.RUnlock() + if n > 0 && uattr.Perms.HasSetUIDOrGID() { + mask := fs.AttrMask{Perms: true} + uattr.Perms.DropSetUIDAndMaybeGID() + if err := h.backingFile.SetMaskedAttributes(ctx, mask, uattr, false); err != nil { + return n, err + } + } return n, err } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 7856b354b..855029b84 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -310,6 +310,12 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode, now := ktime.NowFromContext(ctx) masked := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: size} + if c.attr.Perms.HasSetUIDOrGID() { + masked.Perms = true + attr.Perms = c.attr.Perms + attr.Perms.DropSetUIDAndMaybeGID() + c.attr.Perms = attr.Perms + } if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil { c.dataMu.Unlock() return err @@ -685,13 +691,14 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return done, nil } -// maybeGrowFile grows the file's size if data has been written past the old -// size. +// maybeUpdateAttrs updates the file's attributes after a write. It updates +// size if data has been written past the old size, and setuid/setgid if any +// bytes were written. // // Preconditions: // * rw.c.attrMu must be locked. // * rw.c.dataMu must be locked. -func (rw *inodeReadWriter) maybeGrowFile() { +func (rw *inodeReadWriter) maybeUpdateAttrs(nwritten uint64) { // If the write ends beyond the file's previous size, it causes the // file to grow. if rw.offset > rw.c.attr.Size { @@ -705,6 +712,12 @@ func (rw *inodeReadWriter) maybeGrowFile() { rw.c.attr.Usage = rw.offset rw.c.dirtyAttr.Usage = true } + + // If bytes were written, ensure setuid and setgid are cleared. + if nwritten > 0 && rw.c.attr.Perms.HasSetUIDOrGID() { + rw.c.dirtyAttr.Perms = true + rw.c.attr.Perms.DropSetUIDAndMaybeGID() + } } // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. @@ -732,7 +745,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error segMR := seg.Range().Intersect(mr) ims, err := mf.MapInternal(seg.FileRangeOf(segMR), hostarch.Write) if err != nil { - rw.maybeGrowFile() + rw.maybeUpdateAttrs(done) rw.c.dataMu.Unlock() return done, err } @@ -744,7 +757,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error srcs = srcs.DropFirst64(n) rw.c.dirty.MarkDirty(segMR) if err != nil { - rw.maybeGrowFile() + rw.maybeUpdateAttrs(done) rw.c.dataMu.Unlock() return done, err } @@ -765,7 +778,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error srcs = srcs.DropFirst64(n) // Partial writes are fine. But we must stop writing. if n != src.NumBytes() || err != nil { - rw.maybeGrowFile() + rw.maybeUpdateAttrs(done) rw.c.dataMu.Unlock() return done, err } @@ -774,7 +787,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error seg, gap = gap.NextSegment(), FileRangeGapIterator{} } } - rw.maybeGrowFile() + rw.maybeUpdateAttrs(done) rw.c.dataMu.Unlock() return done, nil } diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 819e140bc..73d80d9b5 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -237,10 +237,20 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I // and availability of a host-mappable FD. if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) { n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset) - } else if f.inodeOperations.fileState.hostMappable != nil { - n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset) } else { - n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) + uattr, e := f.UnstableAttr(ctx, file) + if e != nil { + return 0, e + } + if f.inodeOperations.fileState.hostMappable != nil { + n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset, uattr) + } else { + n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) + if n > 0 && uattr.Perms.HasSetUIDOrGID() { + uattr.Perms.DropSetUIDAndMaybeGID() + f.inodeOperations.SetPermissions(ctx, file.Dirent.Inode, uattr.Perms) + } + } } if n == 0 { diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index b97635ec4..da3178527 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -600,11 +600,25 @@ func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length if i.session().cachePolicy.useCachingInodeOps(inode) { return i.cachingInodeOps.Truncate(ctx, inode, length) } + + uattr, err := i.fileState.unstableAttr(ctx) + if err != nil { + return err + } + if i.session().cachePolicy == cacheRemoteRevalidating { - return i.fileState.hostMappable.Truncate(ctx, length) + return i.fileState.hostMappable.Truncate(ctx, length, uattr) + } + + mask := p9.SetAttrMask{Size: true} + attr := p9.SetAttr{Size: uint64(length)} + if uattr.Perms.HasSetUIDOrGID() { + mask.Permissions = true + uattr.Perms.DropSetUIDAndMaybeGID() + attr.Permissions = p9.FileMode(uattr.Perms.LinuxMode()) } - return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)}) + return i.fileState.file.setAttr(ctx, mask, attr) } // GetXattr implements fs.InodeOperations.GetXattr. diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 6b3627813..940838a44 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -130,7 +130,16 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string panic(fmt.Sprintf("Create called with unknown or unset open flags: %v", flags)) } + // If the parent directory has setgid enabled, change the new file's owner. owner := fs.FileOwnerFromContext(ctx) + parentUattr, err := dir.UnstableAttr(ctx) + if err != nil { + return nil, err + } + if parentUattr.Perms.SetGID { + owner.GID = parentUattr.Owner.GID + } + hostFile, err := newFile.create(ctx, name, openFlags, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)) if err != nil { // Could not create the file. @@ -225,7 +234,18 @@ func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s return syserror.ENAMETOOLONG } + // If the parent directory has setgid enabled, change the new directory's + // owner and enable setgid. owner := fs.FileOwnerFromContext(ctx) + parentUattr, err := dir.UnstableAttr(ctx) + if err != nil { + return err + } + if parentUattr.Perms.SetGID { + owner.GID = parentUattr.Owner.GID + perm.SetGID = true + } + if _, err := i.fileState.file.mkdir(ctx, s, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)); err != nil { return err } diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go index bc117ca6a..b48d475ed 100644 --- a/pkg/sentry/fs/tmpfs/fs.go +++ b/pkg/sentry/fs/tmpfs/fs.go @@ -151,5 +151,5 @@ func (f *Filesystem) Mount(ctx context.Context, device string, flags fs.MountSou } // Construct the tmpfs root. - return NewDir(ctx, nil, owner, perms, msrc), nil + return NewDir(ctx, nil, owner, perms, msrc, nil /* parent */) } diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index f4de8c968..7faa822f0 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -226,6 +226,12 @@ func (f *fileInodeOperations) Truncate(ctx context.Context, _ *fs.Inode, size in now := ktime.NowFromContext(ctx) f.attr.ModificationTime = now f.attr.StatusChangeTime = now + + // Truncating clears privilege bits. + f.attr.Perms.SetUID = false + if f.attr.Perms.Group.Execute { + f.attr.Perms.SetGID = false + } } f.dataMu.Unlock() @@ -363,7 +369,14 @@ func (f *fileInodeOperations) write(ctx context.Context, src usermem.IOSequence, now := ktime.NowFromContext(ctx) f.attr.ModificationTime = now f.attr.StatusChangeTime = now - return src.CopyInTo(ctx, &fileReadWriter{f, offset}) + nwritten, err := src.CopyInTo(ctx, &fileReadWriter{f, offset}) + + // Writing clears privilege bits. + if nwritten > 0 { + f.attr.Perms.DropSetUIDAndMaybeGID() + } + + return nwritten, err } type fileReadWriter struct { diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 577052888..6aa8ff331 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -87,7 +87,20 @@ type Dir struct { var _ fs.InodeOperations = (*Dir)(nil) // NewDir returns a new directory. -func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { +func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource, parent *fs.Inode) (*fs.Inode, error) { + // If the parent has setgid enabled, the new directory enables it and changes + // its GID. + if parent != nil { + parentUattr, err := parent.UnstableAttr(ctx) + if err != nil { + return nil, err + } + if parentUattr.Perms.SetGID { + owner.GID = parentUattr.Owner.GID + perms.SetGID = true + } + } + d := &Dir{ ramfsDir: ramfs.NewDir(ctx, contents, owner, perms), kernel: kernel.KernelFromContext(ctx), @@ -101,7 +114,7 @@ func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwn InodeID: tmpfsDevice.NextIno(), BlockSize: hostarch.PageSize, Type: fs.Directory, - }) + }), nil } // afterLoad is invoked by stateify. @@ -219,11 +232,21 @@ func (d *Dir) SetTimestamps(ctx context.Context, i *fs.Inode, ts fs.TimeSpec) er func (d *Dir) newCreateOps() *ramfs.CreateOps { return &ramfs.CreateOps{ NewDir: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) { - return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil + return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource, dir) }, NewFile: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) { + // If the parent has setgid enabled, change the GID of the new file. + owner := fs.FileOwnerFromContext(ctx) + parentUattr, err := dir.UnstableAttr(ctx) + if err != nil { + return nil, err + } + if parentUattr.Perms.SetGID { + owner.GID = parentUattr.Owner.GID + } + uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{ - Owner: fs.FileOwnerFromContext(ctx), + Owner: owner, Perms: perms, // Always start unlinked. Links: 0, diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go index 12b786224..7f8fa8038 100644 --- a/pkg/sentry/fs/user/user_test.go +++ b/pkg/sentry/fs/user/user_test.go @@ -104,7 +104,10 @@ func TestGetExecUserHome(t *testing.T) { t.Run(name, func(t *testing.T) { ctx := contexttest.Context(t) msrc := fs.NewPseudoMountSource(ctx) - rootInode := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc) + rootInode, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */) + if err != nil { + t.Fatalf("tmpfs.NewDir failed: %v", err) + } mns, err := fs.NewMountNamespace(ctx, rootInode) if err != nil { diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 43bfd69a3..82491a0f8 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -207,9 +207,10 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e return err } - // Changing owners may clear one or both of the setuid and setgid bits, - // so we may have to update opts before setting d.mode. - if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 { + // Changing owners or truncating may clear one or both of the setuid and + // setgid bits, so we may have to update opts before setting d.mode. + inotifyMask := opts.Stat.Mask + if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{ Mask: linux.STATX_MODE, }) @@ -218,10 +219,14 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e } opts.Stat.Mode = stat.Mode opts.Stat.Mask |= linux.STATX_MODE + // Don't generate inotify IN_ATTRIB for size-only changes (truncations). + if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 { + inotifyMask |= linux.STATX_MODE + } } d.updateAfterSetStatLocked(&opts) - if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + if ev := vfs.InotifyEventFromStatMask(inotifyMask); ev != 0 { d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 33e52ce64..438840ae2 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -88,6 +88,7 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating timekeeper: %v", err) } tk.SetClocks(time.NewCalibratedClocks()) + k.SetTimekeeper(tk) creds := auth.NewRootCredentials(auth.NewRootUserNamespace()) @@ -96,7 +97,6 @@ func Boot() (*kernel.Kernel, error) { if err = k.Init(kernel.InitKernelArgs{ ApplicationCores: uint(runtime.GOMAXPROCS(-1)), FeatureSet: cpuid.HostFeatureSet(), - Timekeeper: tk, RootUserNamespace: creds.UserNamespace, Vdso: vdso, RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace), @@ -181,7 +181,7 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) { memfile := os.NewFile(uintptr(memfd), memfileName) mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) if err != nil { - memfile.Close() + _ = memfile.Close() return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err) } return mf, nil diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index fa7696ad6..969003613 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -868,6 +868,10 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa fd.mu.Lock() defer fd.mu.Unlock() + if _, err := fd.lowerFD.Seek(ctx, fd.off, linux.SEEK_SET); err != nil { + return err + } + var ds []vfs.Dirent err := fd.lowerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { // Do not include the Merkle tree files. @@ -890,8 +894,8 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa return err } - // The result should contain all children plus "." and "..". - if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 { + // The result should be a part of all children plus "." and "..", counting from fd.off. + if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2-int(fd.off) { return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) } diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go index 0fbf27f64..c93ef6ac1 100644 --- a/pkg/sentry/kernel/cgroup.go +++ b/pkg/sentry/kernel/cgroup.go @@ -181,7 +181,23 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files for _, h := range r.hierarchies { if h.match(ctypes) { - h.fs.IncRef() + if !h.fs.TryIncRef() { + // Racing with filesystem destruction, namely h.fs.Release. + // Since we hold r.mu, we know the hierarchy hasn't been + // unregistered yet, but its associated filesystem is tearing + // down. + // + // If we simply indicate the hierarchy wasn't found without + // cleaning up the registry, the caller can race with the + // unregister and find itself temporarily unable to create a new + // hierarchy with a subset of the relevant controllers. + // + // To keep the result of FindHierarchy consistent with the + // uniqueness of controllers enforced by Register, drop the + // dying hierarchy now. The eventual unregister by the FS + // teardown will become a no-op. + return nil + } return h.fs } } @@ -230,12 +246,17 @@ func (r *CgroupRegistry) Register(cs []CgroupController, fs cgroupFS) error { return nil } -// Unregister removes a previously registered hierarchy from the registry. If -// the controller was not previously registered, Unregister is a no-op. +// Unregister removes a previously registered hierarchy from the registry. If no +// such hierarchy is registered, Unregister is a no-op. func (r *CgroupRegistry) Unregister(hid uint32) { r.mu.Lock() - defer r.mu.Unlock() + r.unregisterLocked(hid) + r.mu.Unlock() +} +// Precondition: Caller must hold r.mu. +// +checklocks:r.mu +func (r *CgroupRegistry) unregisterLocked(hid uint32) { if h, ok := r.hierarchies[hid]; ok { for name, _ := range h.controllers { delete(r.controllers, name) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index e6e9da898..febe7fe50 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -306,9 +306,6 @@ type InitKernelArgs struct { // FeatureSet is the emulated CPU feature set. FeatureSet *cpuid.FeatureSet - // Timekeeper manages time for all tasks in the system. - Timekeeper *Timekeeper - // RootUserNamespace is the root user namespace. RootUserNamespace *auth.UserNamespace @@ -348,29 +345,34 @@ type InitKernelArgs struct { PIDNamespace *PIDNamespace } +// SetTimekeeper sets Kernel.timekeeper. SetTimekeeper must be called before +// Init. +func (k *Kernel) SetTimekeeper(tk *Timekeeper) { + k.timekeeper = tk +} + // Init initialize the Kernel with no tasks. // // Callers must manually set Kernel.Platform and call Kernel.SetMemoryFile -// before calling Init. +// and Kernel.SetTimekeeper before calling Init. func (k *Kernel) Init(args InitKernelArgs) error { if args.FeatureSet == nil { - return fmt.Errorf("FeatureSet is nil") + return fmt.Errorf("args.FeatureSet is nil") } - if args.Timekeeper == nil { - return fmt.Errorf("Timekeeper is nil") + if k.timekeeper == nil { + return fmt.Errorf("timekeeper is nil") } - if args.Timekeeper.clocks == nil { + if k.timekeeper.clocks == nil { return fmt.Errorf("must call Timekeeper.SetClocks() before Kernel.Init()") } if args.RootUserNamespace == nil { - return fmt.Errorf("RootUserNamespace is nil") + return fmt.Errorf("args.RootUserNamespace is nil") } if args.ApplicationCores == 0 { - return fmt.Errorf("ApplicationCores is 0") + return fmt.Errorf("args.ApplicationCores is 0") } k.featureSet = args.FeatureSet - k.timekeeper = args.Timekeeper k.tasks = newTaskSet(args.PIDNamespace) k.rootUserNamespace = args.RootUserNamespace k.rootUTSNamespace = args.RootUTSNamespace @@ -395,8 +397,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { } k.extraAuxv = args.ExtraAuxv k.vdso = args.Vdso - k.realtimeClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Realtime} - k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} + k.realtimeClock = &timekeeperClock{tk: k.timekeeper, c: sentrytime.Realtime} + k.monotonicClock = &timekeeperClock{tk: k.timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() k.ptraceExceptions = make(map[*Task]*Task) @@ -654,12 +656,12 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { defer k.tasks.mu.RUnlock() for t := range k.tasks.Root.tids { // We can skip locking Task.mu here since the kernel is paused. - if mm := t.image.MemoryManager; mm != nil { - if _, ok := invalidated[mm]; !ok { - if err := mm.InvalidateUnsavable(ctx); err != nil { + if memMgr := t.image.MemoryManager; memMgr != nil { + if _, ok := invalidated[memMgr]; !ok { + if err := memMgr.InvalidateUnsavable(ctx); err != nil { return err } - invalidated[mm] = struct{}{} + invalidated[memMgr] = struct{}{} } } // I really wish we just had a sync.Map of all MMs... @@ -1553,22 +1555,23 @@ func (k *Kernel) SetSaveError(err error) { var _ tcpip.Clock = (*Kernel)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (k *Kernel) NowNanoseconds() int64 { - now, err := k.timekeeper.GetTime(sentrytime.Realtime) +// Now implements tcpip.Clock.NowNanoseconds. +func (k *Kernel) Now() time.Time { + nsec, err := k.timekeeper.GetTime(sentrytime.Realtime) if err != nil { - panic("Kernel.NowNanoseconds: " + err.Error()) + panic("timekeeper.GetTime(sentrytime.Realtime): " + err.Error()) } - return now + return time.Unix(0, nsec) } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (k *Kernel) NowMonotonic() int64 { - now, err := k.timekeeper.GetTime(sentrytime.Monotonic) +func (k *Kernel) NowMonotonic() tcpip.MonotonicTime { + nsec, err := k.timekeeper.GetTime(sentrytime.Monotonic) if err != nil { - panic("Kernel.NowMonotonic: " + err.Error()) + panic("timekeeper.GetTime(sentrytime.Monotonic): " + err.Error()) } - return now + var mt tcpip.MonotonicTime + return mt.Add(time.Duration(nsec) * time.Nanosecond) } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -1783,7 +1786,7 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { }) t := TaskFromContext(ctx) - k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{ + _, _ = k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{ Tid: int32(t.ThreadID()), Registers: t.Arch().StateData().Proto(), }) @@ -1874,7 +1877,7 @@ func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) { return } t.mu.Lock() - for cg, _ := range t.cgroups { + for cg := range t.cgroups { if cg.HierarchyID() == hid { t.leaveCgroupLocked(cg) } diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 2d89b9ccd..24e467e93 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -86,6 +86,12 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) if n > 0 { p.Notify(waiter.ReadableEvents) } + if err == unix.EPIPE { + // If we are returning EPIPE send SIGPIPE to the task. + if sendSig := linux.SignalNoInfoFuncFromContext(ctx); sendSig != nil { + sendSig(linux.SIGPIPE) + } + } return n, err } diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index fe2ab1662..1d9edf118 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -35,10 +35,10 @@ const ( // Maximum number of semaphore sets. setsMax = linux.SEMMNI - // Maximum number of semaphroes in a semaphore set. + // Maximum number of semaphores in a semaphore set. semsMax = linux.SEMMSL - // Maximum number of semaphores in all semaphroe sets. + // Maximum number of semaphores in all semaphore sets. semsTotalMax = linux.SEMMNS ) @@ -220,7 +220,7 @@ func (r *Registry) HighestIndex() int32 { defer r.mu.Unlock() // By default, highest used index is 0 even though - // there is no semaphroe set. + // there is no semaphore set. var highestIndex int32 for index := range r.indexes { if index > highestIndex { @@ -702,7 +702,9 @@ func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool { return s.checkCapability(creds) } -// destroy destroys the set. Caller must hold 's.mu'. +// destroy destroys the set. +// +// Preconditions: Caller must hold 's.mu'. func (s *Set) destroy() { // Notify all waiters. They will fail on the next attempt to execute // operations and return error. diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 70b0699dc..c82d9e82b 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -17,6 +17,7 @@ package kernel import ( "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -113,6 +114,10 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { return t.k.RealtimeClock() case limits.CtxLimits: return t.tg.limits + case linux.CtxSignalNoInfoFunc: + return func(sig linux.Signal) error { + return t.SendSignal(SignalInfoNoInfo(sig, t, t)) + } case pgalloc.CtxMemoryFile: return t.k.mf case pgalloc.CtxMemoryFileProvider: diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index f42d73178..e1c4b06fc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -58,8 +58,8 @@ var nameToID = map[string]stack.TableID{ // DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for // compatibility with netfilter extensions. -func DefaultLinuxTables() *stack.IPTables { - tables := stack.DefaultTables() +func DefaultLinuxTables(seed uint32) *stack.IPTables { + tables := stack.DefaultTables(seed) tables.VisitTargets(func(oldTarget stack.Target) stack.Target { switch val := oldTarget.(type) { case *stack.AcceptTarget: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 0b64a24c3..037ccfec8 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -77,41 +77,59 @@ func mustCreateGauge(name, description string) *tcpip.StatCounter { // Metrics contains metrics exported by netstack. var Metrics = tcpip.Stats{ - UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."), - MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), - DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), + DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped at the transport layer."), + NICs: tcpip.NICStats{ + UnknownL3ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l3_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L3 protocol."), + UnknownL4ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l4_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L4 protocol."), + MalformedL4RcvdPackets: mustCreateMetric("/netstack/nic/malformed_l4_received_packets", "Number of packets received that failed L4 header parsing."), + Tx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/tx/packets", "Number of packets transmitted."), + Bytes: mustCreateMetric("/netstack/nic/tx/bytes", "Number of bytes transmitted."), + }, + Rx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/rx/packets", "Number of packets received."), + Bytes: mustCreateMetric("/netstack/nic/rx/bytes", "Number of bytes received."), + }, + DisabledRx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/disabled_rx/packets", "Number of packets received on disabled NICs."), + Bytes: mustCreateMetric("/netstack/nic/disabled_rx/bytes", "Number of bytes received on disabled NICs."), + }, + Neighbor: tcpip.NICNeighborStats{ + UnreachableEntryLookups: mustCreateMetric("/netstack/nic/neighbor/unreachable_entry_loopups", "Number of lookups performed on a neighbor entry in Unreachable state."), + }, + }, ICMP: tcpip.ICMPStats{ V4: tcpip.ICMPv4Stats{ PacketsSent: tcpip.ICMPv4SentPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped by netstack due to link layer errors."), - RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped by netstack due to rate limit being exceeded."), + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received."), }, Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Number of ICMPv4 packets received that the transport layer could not parse."), }, @@ -119,40 +137,40 @@ var Metrics = tcpip.Stats{ V6: tcpip.ICMPv6Stats{ PacketsSent: tcpip.ICMPv6SentPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent by netstack."), - MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent by netstack."), - MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), - MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped by netstack due to link layer errors."), - RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped by netstack due to rate limit being exceeded."), + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received by netstack."), - MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received by netstack."), - MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), - MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."), }, Unrecognized: mustCreateMetric("/netstack/icmp/v6/packets_received/unrecognized", "Number of ICMPv6 packets received that the transport layer does not know how to parse."), Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Number of ICMPv6 packets received that the transport layer could not parse."), @@ -163,23 +181,23 @@ var Metrics = tcpip.Stats{ IGMP: tcpip.IGMPStats{ PacketsSent: tcpip.IGMPSentPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent."), }, - Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped due to link layer errors."), }, PacketsReceived: tcpip.IGMPReceivedPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received."), }, - Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received by netstack that could not be parsed."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received that could not be parsed."), ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Number of received IGMP packets with bad checksums."), - Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received by netstack."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received."), }, }, IP: tcpip.IPStats{ @@ -205,7 +223,8 @@ var Metrics = tcpip.Stats{ LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), ExtensionHeaderProblem: mustCreateMetric("/netstack/ip/forwarding/extension_header_problem", "Number of IP packets received which could not be forwarded due to a problem processing their IPv6 extension headers."), - PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not fit within the outgoing MTU."), + PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not be forwarded because they could not fit within the outgoing MTU."), + HostUnreachable: mustCreateMetric("/netstack/ip/forwarding/host_unreachable", "Number of IP packets received which could not be forwarded due to unresolvable next hop."), Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), }, }, @@ -1126,7 +1145,14 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, // TODO(b/64800844): Translate fields once they are added to // tcpip.TCPInfoOption. - info := linux.TCPInfo{} + info := linux.TCPInfo{ + State: uint8(v.State), + RTO: uint32(v.RTO / time.Microsecond), + RTT: uint32(v.RTT / time.Microsecond), + RTTVar: uint32(v.RTTVar / time.Microsecond), + SndSsthresh: v.SndSsthresh, + SndCwnd: v.SndCwnd, + } switch v.CcState { case tcpip.RTORecovery: info.CaState = linux.TCP_CA_Loss @@ -1137,11 +1163,6 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, case tcpip.Open: info.CaState = linux.TCP_CA_Open } - info.RTO = uint32(v.RTO / time.Microsecond) - info.RTT = uint32(v.RTT / time.Microsecond) - info.RTTVar = uint32(v.RTTVar / time.Microsecond) - info.SndSsthresh = v.SndSsthresh - info.SndCwnd = v.SndCwnd // In netstack reorderSeen is updated only when RACK is enabled. // We only track whether the reordering is seen, which is diff --git a/pkg/sentry/socket/netstack/tun.go b/pkg/sentry/socket/netstack/tun.go index 288dd0c9e..c7ed52702 100644 --- a/pkg/sentry/socket/netstack/tun.go +++ b/pkg/sentry/socket/netstack/tun.go @@ -40,7 +40,7 @@ func LinuxToTUNFlags(flags uint16) (tun.Flags, error) { // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) // when there is no sk_filter. See __tun_chr_ioctl() in // net/drivers/tun.c. - if flags&^uint16(linux.IFF_TUN|linux.IFF_TAP|linux.IFF_NO_PI) != 0 { + if flags&^uint16(linux.IFF_TUN|linux.IFF_TAP|linux.IFF_NO_PI|linux.IFF_ONE_QUEUE) != 0 { return tun.Flags{}, syserror.EINVAL } return tun.Flags{ diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 9cd238efd..37443ab78 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -1673,9 +1673,11 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { if err != nil { return err } + c := t.Credentials() hasCap := d.Inode.CheckCapability(t, linux.CAP_CHOWN) isOwner := uattr.Owner.UID == c.EffectiveKUID + var clearPrivilege bool if uid.Ok() { kuid := c.UserNamespace.MapToKUID(uid) // Valid UID must be supplied if UID is to be changed. @@ -1693,6 +1695,11 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { return syserror.EPERM } + // The setuid and setgid bits are cleared during a chown. + if uattr.Owner.UID != kuid { + clearPrivilege = true + } + owner.UID = kuid } if gid.Ok() { @@ -1711,6 +1718,11 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { return syserror.EPERM } + // The setuid and setgid bits are cleared during a chown. + if uattr.Owner.GID != kgid { + clearPrivilege = true + } + owner.GID = kgid } @@ -1721,10 +1733,14 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { if err := d.Inode.SetOwner(t, d, owner); err != nil { return err } + // Clear privilege bits if needed and they are set. + if clearPrivilege && uattr.Perms.HasSetUIDOrGID() && !fs.IsDir(d.Inode.StableAttr) { + uattr.Perms.DropSetUIDAndMaybeGID() + if !d.Inode.SetPermissions(t, d, uattr.Perms) { + return syserror.EPERM + } + } - // When the owner or group are changed by an unprivileged user, - // chown(2) also clears the set-user-ID and set-group-ID bits, but - // we do not support them. return nil } diff --git a/pkg/sync/README.md b/pkg/sync/README.md index 2183c4e20..be1a01f08 100644 --- a/pkg/sync/README.md +++ b/pkg/sync/README.md @@ -1,4 +1,4 @@ -# Syncutil +# sync This package provides additional synchronization primitives not provided by the Go stdlib 'sync' package. It is partially derived from the upstream 'sync' diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 18e6cc3cd..e0dfe5813 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -50,7 +50,7 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { ipv4 := header.IPv4(b) if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") + t.Fatalf("Not a valid IPv4 packet: %x", ipv4) } if !ipv4.IsChecksumValid() { @@ -72,7 +72,7 @@ func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { ipv6 := header.IPv6(b) if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") + t.Fatalf("Not a valid IPv6 packet: %x", ipv6) } for _, f := range checkers { @@ -701,7 +701,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp if !ok { return } - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) foundTS := false tsVal := uint32(0) @@ -748,12 +748,6 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does -// not contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - // TCPSACKBlockChecker creates a checker that verifies that the segment does // contain the specified SACK blocks in the TCP options. func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { @@ -765,7 +759,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } var gotSACKBlocks []header.SACKBlock - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) for i := 0; i < limit; { switch opts[i] { diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index fb819d7a8..9f8f51647 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -29,14 +29,14 @@ type NullClock struct{} var _ tcpip.Clock = (*NullClock)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (*NullClock) NowNanoseconds() int64 { - return 0 +// Now implements tcpip.Clock.Now. +func (*NullClock) Now() time.Time { + return time.Time{} } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (*NullClock) NowMonotonic() int64 { - return 0 +func (*NullClock) NowMonotonic() tcpip.MonotonicTime { + return tcpip.MonotonicTime{} } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -118,16 +118,17 @@ func NewManualClock() *ManualClock { var _ tcpip.Clock = (*ManualClock)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (mc *ManualClock) NowNanoseconds() int64 { +// Now implements tcpip.Clock.Now. +func (mc *ManualClock) Now() time.Time { mc.mu.RLock() defer mc.mu.RUnlock() - return mc.mu.now.UnixNano() + return mc.mu.now } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (mc *ManualClock) NowMonotonic() int64 { - return mc.NowNanoseconds() +func (mc *ManualClock) NowMonotonic() tcpip.MonotonicTime { + var mt tcpip.MonotonicTime + return mt.Add(mc.Now().Sub(time.Unix(0, 0))) } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -218,6 +219,12 @@ func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { } } +// RunImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func (mc *ManualClock) RunImmediatelyScheduledJobs() { + mc.Advance(0) +} + // Advance executes all work that have been scheduled to execute within d from // the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go index c2704df2c..fd2bb470a 100644 --- a/pkg/tcpip/faketime/faketime_test.go +++ b/pkg/tcpip/faketime/faketime_test.go @@ -26,7 +26,7 @@ func TestManualClockAdvance(t *testing.T) { clock := faketime.NewManualClock() start := clock.NowMonotonic() clock.Advance(timeout) - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { + if got, want := clock.NowMonotonic().Sub(start), timeout; got != want { t.Errorf("got = %d, want = %d", got, want) } } @@ -87,7 +87,7 @@ func TestManualClockAfterFunc(t *testing.T) { if got, want := counter2, test.wantCounter2; got != want { t.Errorf("got counter2 = %d, want = %d", got, want) } - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { + if got, want := clock.NowMonotonic().Sub(start), test.advance; got != want { t.Errorf("got elapsed = %d, want = %d", got, want) } }) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 2be21ec75..6a8db84d6 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -813,9 +814,12 @@ const ( // ipv4TimestampTime provides the current time as specified in RFC 791. func ipv4TimestampTime(clock tcpip.Clock) uint32 { - const millisecondsPerDay = 24 * 3600 * 1000 - const nanoPerMilli = 1000000 - return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay) + // Per RFC 791 page 21: + // The Timestamp is a right-justified, 32-bit timestamp in + // milliseconds since midnight UT. + now := clock.Now().UTC() + midnight := now.Truncate(24 * time.Hour) + return uint32(now.Sub(midnight).Milliseconds()) } // IP Timestamp option fields. diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 3d1bccd15..d6cad3a94 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -77,12 +77,12 @@ const ( // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag // field in the flags byte within an NDPPrefixInformation. - ndpPrefixInformationOnLinkFlagMask = (1 << 7) + ndpPrefixInformationOnLinkFlagMask = 1 << 7 // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the // Autonomous Address-Configuration flag field in the flags byte within // an NDPPrefixInformation. - ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6 // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 // field in the flags byte within an NDPPrefixInformation. @@ -451,7 +451,7 @@ func (o NDPNonceOption) String() string { // Nonce returns the nonce value this option holds. func (o NDPNonceOption) Nonce() []byte { - return []byte(o) + return o } // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 0df517000..8dabe3354 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -48,6 +48,16 @@ const ( // TCPFlags is the dedicated type for TCP flags. type TCPFlags uint8 +// Intersects returns true iff there are flags common to both f and o. +func (f TCPFlags) Intersects(o TCPFlags) bool { + return f&o != 0 +} + +// Contains returns true iff all the flags in o are contained within f. +func (f TCPFlags) Contains(o TCPFlags) bool { + return f&o == o +} + // String implements Stringer.String. func (f TCPFlags) String() string { flagsStr := []byte("FSRPAU") diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index ef9126deb..f26c857eb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -288,5 +288,5 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index a72eb1aad..6fa1aee18 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -28,13 +28,13 @@ go_test( ":arp", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 94209b026..5fcbfeaa2 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,15 +15,14 @@ package arp_test import ( - "context" "fmt" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -31,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) const ( @@ -39,15 +37,6 @@ const ( stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - - defaultChannelSize = 1 - defaultMTU = 65536 - - // eventChanSize defines the size of event channels used by the neighbor - // cache's event dispatcher. The size chosen here needs to be sufficient to - // queue all the events received during tests before consumption. - // If eventChanSize is too small, the tests may deadlock. - eventChanSize = 32 ) var ( @@ -123,24 +112,6 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo d.C <- e } -func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { - select { - case got := <-d.C: - if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - case <-ctx.Done(): - return fmt.Errorf("%s for %s", ctx.Err(), want) - } - return nil -} - -func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return d.waitForEvent(ctx, want) -} - func (d *arpDispatcher) nextEvent() (eventInfo, bool) { select { case event := <-d.C: @@ -153,55 +124,45 @@ func (d *arpDispatcher) nextEvent() (eventInfo, bool) { type testContext struct { s *stack.Stack linkEP *channel.Endpoint - nudDisp *arpDispatcher + nudDisp arpDispatcher } -func newTestContext(t *testing.T) *testContext { - c := stack.DefaultNUDConfigurations() - // Transition from Reachable to Stale almost immediately to test if receiving - // probes refreshes positive reachability. - c.BaseReachableTime = time.Microsecond - - d := arpDispatcher{ - // Create an event channel large enough so the neighbor cache doesn't block - // while dispatching events. Blocking could interfere with the timing of - // NUD transitions. - C: make(chan eventInfo, eventChanSize), +func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext { + t.Helper() + + tc := testContext{ + nudDisp: arpDispatcher{ + C: make(chan eventInfo, eventDepth), + }, } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - NUDConfigs: c, - NUDDisp: &d, + tc.s = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + NUDDisp: &tc.nudDisp, + Clock: &faketime.NullClock{}, }) - ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - wep := stack.LinkEndpoint(ep) + tc.linkEP = channel.New(packetDepth, header.IPv4MinimumMTU, stackLinkAddr) + tc.linkEP.LinkEPCapabilities |= stack.CapabilityResolutionRequired + wep := stack.LinkEndpoint(tc.linkEP) if testing.Verbose() { - wep = sniffer.New(ep) + wep = sniffer.New(wep) } - if err := s.CreateNIC(nicID, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + if err := tc.s.CreateNIC(nicID, wep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) + if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %s", err) } - s.SetRouteTable([]tcpip.Route{{ + tc.s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: nicID, }}) - return &testContext{ - s: s, - linkEP: ep, - nudDisp: &d, - } + return tc } func (c *testContext) cleanup() { @@ -209,7 +170,7 @@ func (c *testContext) cleanup() { } func TestMalformedPacket(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() v := make(buffer.View, header.ARPSize) @@ -228,7 +189,7 @@ func TestMalformedPacket(t *testing.T) { } func TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) @@ -253,7 +214,7 @@ func TestDisabledEndpoint(t *testing.T) { } func TestDirectReply(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -284,7 +245,7 @@ func TestDirectReply(t *testing.T) { } func TestDirectRequest(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 1, 1) defer c.cleanup() tests := []struct { @@ -391,17 +352,21 @@ func TestDirectRequest(t *testing.T) { } // Verify the sender was saved in the neighbor cache. - wantEvent := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: stack.NeighborEntry{ - Addr: test.senderAddr, - LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), - State: stack.Stale, - }, - } - if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { - t.Fatal(err) + if got, ok := c.nudDisp.nextEvent(); ok { + want := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: test.senderLinkAddr, + State: stack.Stale, + }, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + t.Errorf("got invalid event (-want +got):\n%s", diff) + } + } else { + t.Fatal("event didn't arrive") } neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) @@ -589,7 +554,7 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, }) - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + linkEP := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -663,15 +628,16 @@ func TestLinkAddressRequest(t *testing.T) { } func TestDADARPRequestPacket(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, }, }), ipv4.NewProtocol}, + Clock: clock, }) - e := channel.New(1, defaultMTU, stackLinkAddr) + e := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -682,7 +648,8 @@ func TestDADARPRequestPacket(t *testing.T) { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting) } - pkt, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + pkt, ok := e.Read() if !ok { t.Fatal("expected to send an ARP request") } diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go index 5168f5361..1ba4d0d36 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go @@ -251,7 +251,7 @@ func (f *Fragmentation) releaseReassemblersLocked() { // The list is empty. break } - elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + elapsed := now.Sub(r.createdAt) if f.timeout > elapsed { // If the oldest reassembler has not expired, schedule the release // job so that this function is called back when it has expired. diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 7daf64b4a..dadfc28cc 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -275,15 +275,23 @@ func TestMemoryLimits(t *testing.T) { highLimit := 3 * lowLimit // Allow at most 3 such packets. f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) + if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) + if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) + if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { + t.Fatal(err) + } if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -300,9 +308,13 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { memSize := pkt(1, "0").MemSize() f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 56b76a284..5b7e4b361 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -35,21 +35,21 @@ type hole struct { type reassembler struct { reassemblerEntry - id FragmentID - memSize int - proto uint8 - mu sync.Mutex - holes []hole - filled int - done bool - creationTime int64 - pkt *stack.PacketBuffer + id FragmentID + memSize int + proto uint8 + mu sync.Mutex + holes []hole + filled int + done bool + createdAt tcpip.MonotonicTime + pkt *stack.PacketBuffer } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ - id: id, - creationTime: clock.NowMonotonic(), + id: id, + createdAt: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index a22b712c6..24687cf06 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -133,7 +133,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) } // Wait for any initially fired timers to complete. - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -147,7 +147,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -156,7 +156,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -170,7 +170,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -208,7 +208,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -247,7 +247,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -272,7 +272,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -347,7 +347,7 @@ func TestNonce(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() for i, want := range test.expectedResults { if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index d22974b12..671dfbf32 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -611,7 +611,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // If a timer for any address is already running, it is reset to the new // random value only if the requested Maximum Response Delay is less than // the remaining value of the running timer. - now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds()) + now := g.opts.Clock.Now() if info.state == delayingMember { if info.delayedReportJobFiresAt.IsZero() { panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress)) diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 0c2b62127..40ab21cb6 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -42,6 +42,10 @@ type MultiCounterIPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig tcpip.MultiCounterStat + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable tcpip.MultiCounterStat + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -61,6 +65,7 @@ func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig) m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) + m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable) } // LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index cec3e62c4..a180e5c75 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip/network/internal/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", + "//pkg/tcpip/tests/integration:__pkg__", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index d1a82b584..5f6b0c6af 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -222,7 +222,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -481,6 +480,22 @@ func (*icmpReasonFragmentationNeeded) isForwarding() bool { return true } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 792 page 5, Destination Unreachable Message, + // + // In addition, in some networks, the gateway may be able to determine + // if the internet destination host is unreachable. Gateways in these + // networks may send destination unreachable messages to the source host + // when the destination host is unreachable. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -537,7 +552,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -653,6 +673,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4NetUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4HostUnreachable) + counter = sent.dstUnreachable case *icmpReasonFragmentationNeeded: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 23178277a..bb8d53c12 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -104,6 +104,16 @@ type endpoint struct { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, return an ICMP error to the original + // packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -898,7 +908,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) case *ip.ErrNoRoute: stats.ip.Forwarding.Unrouteable.Increment() case *ip.ErrParameterProblem: - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() case *ip.ErrMessageTooLong: stats.ip.Forwarding.PacketTooBig.Increment() @@ -935,7 +944,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -1008,7 +1016,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() } return diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index da9cc0ae8..0a1f02de0 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -16,7 +16,6 @@ package ipv4_test import ( "bytes" - "context" "encoding/hex" "fmt" "io/ioutil" @@ -36,10 +35,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" - tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -112,10 +111,6 @@ func TestExcludeBroadcast(t *testing.T) { }) } -type forwardedPacket struct { - fragments []fragmentInfo -} - func TestForwarding(t *testing.T) { const ( incomingNICID = 1 @@ -134,11 +129,11 @@ func TestForwarding(t *testing.T) { PrefixLen: 8, } outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2") - remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2") - unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2") - multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0") - linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0") + remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2") + remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2") + unreachableIPv4Addr := testutil.MustParse4("12.0.0.2") + multicastIPv4Addr := testutil.MustParse4("225.0.0.0") + linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0") tests := []struct { name string @@ -402,13 +397,13 @@ func TestForwarding(t *testing.T) { totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) - icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) + icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ TotalLength: uint16(totalLength), @@ -453,7 +448,7 @@ func TestForwarding(t *testing.T) { return len(hdr.View()) } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(incomingIPv4Addr.Address), checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), @@ -461,7 +456,7 @@ func TestForwarding(t *testing.T) { checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), + checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) } else if ok { @@ -470,7 +465,7 @@ func TestForwarding(t *testing.T) { if test.expectPacketForwarded { if len(test.expectedFragmentsForwarded) != 0 { - fragmentedPackets := []*stack.PacketBuffer{} + var fragmentedPackets []*stack.PacketBuffer for i := 0; i < len(test.expectedFragmentsForwarded); i++ { reply, ok = outgoingEndpoint.Read() if !ok { @@ -487,7 +482,7 @@ func TestForwarding(t *testing.T) { // maximum IP header size and the maximum size allocated for link layer // headers. In this case, no size is allocated for link layer headers. expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize - if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { + if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { t.Error(err) } } else { @@ -496,7 +491,7 @@ func TestForwarding(t *testing.T) { t.Fatal("expected ICMP Echo packet through outgoing NIC") } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), @@ -1212,15 +1207,15 @@ func TestIPv4Sanity(t *testing.T) { } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) // Specify ident/seq to make sure we get the same in the response. - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength @@ -1315,7 +1310,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1334,7 +1329,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4( checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1546,9 +1541,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -1602,7 +1597,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) for _, test := range writePacketsTests { t.Run(test.description, func(t *testing.T) { @@ -1612,13 +1607,13 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -1726,8 +1721,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2301,7 +2296,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv4Type(header.ICMPv4TimeExceeded), checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), checker.ICMPv4Checksum(), - checker.ICMPv4Payload([]byte(firstFragmentSent)), + checker.ICMPv4Payload(firstFragmentSent), ), ) }) @@ -2705,7 +2700,7 @@ func TestReceiveFragments(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, RawFactory: raw.EndpointFactory{}, }) - e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) + e := channel.New(0, 1280, "\xf0\x00") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -2948,7 +2943,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -3074,7 +3069,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ @@ -3090,7 +3085,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3133,7 +3128,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3157,9 +3152,11 @@ func TestPacketQueing(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -3182,7 +3179,8 @@ func TestPacketQueing(t *testing.T) { // Wait for a ARP request since link address resolution should be // performed. { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3223,6 +3221,7 @@ func TestPacketQueing(t *testing.T) { } // Expect the response now that the link address has resolved. + clock.RunImmediatelyScheduledJobs() test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed @@ -3244,8 +3243,8 @@ func TestCloseLocking(t *testing.T) { ) var ( - src = tcptestutil.MustParse4("16.0.0.1") - dst = tcptestutil.MustParse4("16.0.0.2") + src = testutil.MustParse4("16.0.0.1") + dst = testutil.MustParse4("16.0.0.2") ) s := stack.New(stack.Options{ @@ -3253,7 +3252,7 @@ func TestCloseLocking(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - // Perform NAT so that the endoint tries to search for a sibling endpoint + // Perform NAT so that the endpoint tries to search for a sibling endpoint // which ends up taking the protocol and endpoint lock (in that order). table := stack.Table{ Rules: []stack.Rule{ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 307e1972d..23fc94303 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -1029,6 +1029,26 @@ func (*icmpReasonNetUnreachable) respondsToMulticast() bool { return false } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 4443 page 8, Destination Unreachable Message, + // + // If the reason for the failure to deliver cannot be mapped to any of + // other codes, the Code field is set to 3. Example of such cases are + // an inability to resolve the IPv6 destination address into a + // corresponding link address, or a link-specific problem of some sort. + return true +} + +func (*icmpReasonHostUnreachable) respondsToMulticast() bool { + return false +} + // icmpReasonFragmentationNeeded is an error where a packet is to big to be sent // out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message. type icmpReasonPacketTooBig struct{} @@ -1143,7 +1163,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -1222,6 +1247,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6AddressUnreachable) + counter = sent.dstUnreachable case *icmpReasonPacketTooBig: icmpHdr.SetType(header.ICMPv6PacketTooBig) icmpHdr.SetCode(header.ICMPv6UnusedCode) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 040cd4bc8..c2e9544c1 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -16,17 +16,16 @@ package ipv6 import ( "bytes" - "context" "net" "reflect" "strings" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -46,16 +45,12 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 30 * time.Second - arbitraryHopLimit = 42 ) var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) - lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -371,6 +366,8 @@ type testContext struct { linkEP0 *channel.Endpoint linkEP1 *channel.Endpoint + + clock *faketime.ManualClock } type endpointWithResolutionCapability struct { @@ -382,15 +379,19 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab } func newTestContext(t *testing.T) *testContext { + clock := faketime.NewManualClock() c := &testContext{ s0: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + Clock: clock, }), s1: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + Clock: clock, }), + clock: clock, } c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) @@ -457,10 +458,14 @@ type routeArgs struct { remoteLinkAddr tcpip.LinkAddress } -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { +func routeICMPv6Packet(t *testing.T, clock *faketime.ManualClock, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pi, _ := args.src.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + pi, ok := args.src.Read() + if !ok { + t.Fatal("packet didn't arrive") + } { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -533,7 +538,7 @@ func TestLinkResolution(t *testing.T) { {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { + routeICMPv6Packet(t, c.clock, args, func(t *testing.T, icmpv6 header.ICMPv6) { if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) } @@ -544,7 +549,7 @@ func TestLinkResolution(t *testing.T) { {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, } { - routeICMPv6Packet(t, args, nil) + routeICMPv6Packet(t, c.clock, args, nil) } } @@ -1309,7 +1314,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1325,7 +1330,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1371,7 +1376,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1396,9 +1401,11 @@ func TestPacketQueing(t *testing.T) { e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -1421,7 +1428,8 @@ func TestPacketQueing(t *testing.T) { // Wait for a neighbor solicitation since link address resolution should // be performed. { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1475,6 +1483,7 @@ func TestPacketQueing(t *testing.T) { } // Expect the response now that the link address has resolved. + clock.RunImmediatelyScheduledJobs() test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 95e11ac51..a9de6020f 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -184,7 +184,6 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol - stack *stack.Stack stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is @@ -282,6 +281,16 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, we should return an ICMP error to the + // original packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -666,8 +675,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params if length > math.MaxUint16 { return &tcpip.ErrMessageTooLong{} } - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) - ip.Encode(&header.IPv6Fields{ + header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)).Encode(&header.IPv6Fields{ PayloadLength: uint16(length), TransportProtocol: params.Protocol, HopLimit: params.TTL, @@ -909,20 +917,20 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return &tcpip.ErrMalformedHeader{} } - ip := header.IPv6(h) + ipH := header.IPv6(h) // Always set the payload length. pktSize := pkt.Data().Size() - ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + ipH.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) // Set the source address when zero. - if ip.SourceAddress() == header.IPv6Any { - ip.SetSourceAddress(r.LocalAddress()) + if ipH.SourceAddress() == header.IPv6Any { + ipH.SetSourceAddress(r.LocalAddress()) } // Set the destination. If the packet already included a destination, it will // be part of the route anyways. - ip.SetDestinationAddress(r.RemoteAddress()) + ipH.SetDestinationAddress(r.RemoteAddress()) // Populate the packet buffer's network header and don't allow an invalid // packet to be sent. @@ -2180,7 +2188,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error return 0, &tcpip.ErrMalformedHeader{} } - networkMTU := linkMTU - uint32(networkHeadersLen) + networkMTU := linkMTU - networkHeadersLen if networkMTU > maxPayloadSize { networkMTU = maxPayloadSize } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 71d1c3e28..bc9cf6999 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -16,6 +16,7 @@ package ipv6_test import ( "bytes" + "math/rand" "testing" "time" @@ -138,7 +139,8 @@ func TestSendQueuedMLDReports(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - SecureRNG: &secureRNG, + SecureRNG: &secureRNG, + RandSource: rand.NewSource(time.Now().UnixNano()), NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: test.dadTransmits, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index f0ff111c5..851cd6e75 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -16,7 +16,6 @@ package ipv6 import ( "fmt" - "math/rand" "time" "gvisor.dev/gvisor/pkg/sync" @@ -549,7 +548,7 @@ type tempSLAACAddrState struct { // Must not be nil. regenJob *tcpip.Job - createdAt time.Time + createdAt tcpip.MonotonicTime // The address's endpoint. // @@ -573,10 +572,10 @@ type slaacPrefixState struct { invalidationJob *tcpip.Job // Nonzero only when the address is not valid forever. - validUntil time.Time + validUntil tcpip.MonotonicTime // Nonzero only when the address is not preferred forever. - preferredUntil time.Time + preferredUntil tcpip.MonotonicTime // State associated with the stable address generated for the prefix. stableAddr struct { @@ -1051,7 +1050,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, } - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // The time an address is preferred until is needed to properly generate the // address. @@ -1182,7 +1181,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt state.stableAddr.localGenerationFailures++ } - if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, ndp.ep.protocol.stack.Clock().NowMonotonic().Sub(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { state.stableAddr.addressEndpoint = addressEndpoint state.generationAttempts++ return true @@ -1237,13 +1236,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary // address is the lower of the valid lifetime of the stable address or the // maximum temporary address valid lifetime. vl := ndp.configs.MaxTempAddrValidLifetime - if prefixState.validUntil != (time.Time{}) { + if prefixState.validUntil != (tcpip.MonotonicTime{}) { if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { vl = prefixVL } @@ -1259,7 +1258,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // maximum temporary address preferred lifetime - the temporary address desync // factor. pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor - if prefixState.preferredUntil != (time.Time{}) { + if prefixState.preferredUntil != (tcpip.MonotonicTime{}) { if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { // Respect the preferred lifetime of the prefix, as per RFC 4941 section // 3.3 step 4. @@ -1394,7 +1393,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // deprecation job so it can be reset. prefixState.deprecationJob.Cancel() - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { @@ -1403,7 +1402,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } prefixState.preferredUntil = now.Add(pl) } else { - prefixState.preferredUntil = time.Time{} + prefixState.preferredUntil = tcpip.MonotonicTime{} } // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: @@ -1421,17 +1420,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // Handle the infinite valid lifetime separately as we do not schedule a // job in this case. prefixState.invalidationJob.Cancel() - prefixState.validUntil = time.Time{} + prefixState.validUntil = tcpip.MonotonicTime{} } else { var effectiveVl time.Duration var rl time.Duration // If the prefix was originally set to be valid forever, assume the // remaining time to be the maximum possible value. - if prefixState.validUntil == (time.Time{}) { + if prefixState.validUntil == (tcpip.MonotonicTime{}) { rl = header.NDPInfiniteLifetime } else { - rl = time.Until(prefixState.validUntil) + rl = prefixState.validUntil.Sub(now) } if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { @@ -1463,7 +1462,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // maximum temporary address valid lifetime. Note, the valid lifetime of a // temporary address is relative to the address's creation time. validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) - if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { + if prefixState.validUntil != (tcpip.MonotonicTime{}) && validUntil.Sub(prefixState.validUntil) > 0 { validUntil = prefixState.validUntil } @@ -1483,7 +1482,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // desync factor. Note, the preferred lifetime of a temporary address is // relative to the address's creation time. preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) - if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { + if prefixState.preferredUntil != (tcpip.MonotonicTime{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { preferredUntil = prefixState.preferredUntil } @@ -1718,7 +1717,7 @@ func (ndp *ndpState) startSolicitingRouters() { // 4861 section 6.3.7. var delay time.Duration if ndp.configs.MaxRtrSolicitationDelay > 0 { - delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) + delay = time.Duration(ndp.ep.protocol.stack.Rand().Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } // Protected by ndp.ep.mu. @@ -1754,7 +1753,7 @@ func (ndp *ndpState) startSolicitingRouters() { header.NDPSourceLinkLayerAddressOption(linkAddress), } } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + optsSerializer.Length() icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) rs := header.NDPRouterSolicit(icmpData.MessageBody()) @@ -1861,7 +1860,7 @@ func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { - ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) + ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor))) } } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 234e34952..2c2416328 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -16,7 +16,7 @@ package ipv6 import ( "bytes" - "context" + "math/rand" "strings" "testing" "time" @@ -32,58 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - t.Cleanup(ep.Close) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := llladdr.WithPrefix() - if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - addressEP.DecRef() - } - - return s, ep -} - var _ NDPDispatcher = (*testNDPDispatcher)(nil) // testNDPDispatcher is an NDPDispatcher only allows default router discovery. @@ -163,11 +111,6 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } -type linkResolutionResult struct { - linkAddr tcpip.LinkAddress - ok bool -} - // TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -456,8 +399,10 @@ func TestNeighborSolicitationResponse(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + Clock: clock, }) e := channel.New(1, 1280, nicLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -527,7 +472,8 @@ func TestNeighborSolicitationResponse(t *testing.T) { } if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, got := e.Read() if !got { t.Fatal("expected an NDP NS response") } @@ -582,7 +528,8 @@ func TestNeighborSolicitationResponse(t *testing.T) { })) } - p, got := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, got := e.Read() if !got { t.Fatal("expected an NDP NA response") } @@ -906,12 +853,12 @@ func TestNDPValidation(t *testing.T) { routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp[:typ.size], + icmpH := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmpH[typ.size:], typ.extraData) + icmpH.SetType(typ.typ) + icmpH.SetCode(test.code) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH[:typ.size], Src: lladdr0, Dst: lladdr1, PayloadCsum: header.Checksum(typ.extraData /* initial */, 0), @@ -937,7 +884,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) + handleIPv6Payload(buffer.View(icmpH), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { @@ -1297,8 +1244,9 @@ func TestCheckDuplicateAddress(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - SecureRNG: &secureRNG, - Clock: clock, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ DADConfigs: dadConfigs, })}, @@ -1315,7 +1263,8 @@ func TestCheckDuplicateAddress(t *testing.T) { snmc := header.SolicitedNodeAddr(lladdr0) remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) checkDADMsg := func() { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("expected %d-th DAD message", dadPacketsSent) } diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index b5b013b64..854d6a6ba 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -101,7 +101,7 @@ func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) { // Wildcard destinations affect all destinations for TupleOnly. if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress { // Only bitwise and the TupleOnlyFlag. - intersection &= ((^TupleOnlyFlag) | counter.SharedFlags()) + intersection &= (^TupleOnlyFlag) | counter.SharedFlags() count++ } } @@ -238,13 +238,13 @@ type PortTester func(port uint16) (good bool, err tcpip.Error) // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) { +func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) { pm.ephemeralMu.RLock() firstEphemeral := pm.firstEphemeral numEphemeral := pm.numEphemeral pm.ephemeralMu.RUnlock() - offset := uint32(rand.Int31n(int32(numEphemeral))) + offset := uint32(rng.Int31n(int32(numEphemeral))) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } @@ -303,7 +303,7 @@ func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) // An optional PortTester can be passed in which if provided will be used to // test if the picked port can be used. The function should return true if the // port is safe to use, false otherwise. -func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { +func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -328,7 +328,7 @@ func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reserv } // A port wasn't specified, so try to find one. - return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) { res.Port = p if !pm.reserveSpecificPortLocked(res) { return false, nil diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 6c4fb8c68..a91b130df 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -18,6 +18,7 @@ import ( "math" "math/rand" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -331,6 +332,7 @@ func TestPortReservation(t *testing.T) { t.Run(test.tname, func(t *testing.T) { pm := NewPortManager() net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} + rng := rand.New(rand.NewSource(time.Now().UnixNano())) for _, test := range test.actions { first, _ := pm.PortRange() @@ -356,7 +358,7 @@ func TestPortReservation(t *testing.T) { BindToDevice: test.device, Dest: test.dest, } - gotPort, err := pm.ReservePort(portRes, nil /* testPort */) + gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */) if diff := cmp.Diff(test.want, err); diff != "" { t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff) } @@ -417,10 +419,11 @@ func TestPickEphemeralPort(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { t.Fatalf("failed to set ephemeral port range: %s", err) } - port, err := pm.PickEphemeralPort(test.f) + port, err := pm.PickEphemeralPort(rng, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 84aa6a9e4..395ff9a07 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5720e7543..f7fbcbaa7 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -125,6 +125,8 @@ type conn struct { tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. + // + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. lastUsed time.Time `state:".(unixTime)"` } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7107d598d..72f66441f 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -114,10 +115,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -134,7 +131,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -224,7 +221,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { - time.AfterFunc(f.proto.addrResolveDelay, func() { + f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() { fn(f.proto.neigh, addr, remoteLinkAddr) }) } @@ -319,7 +316,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -354,17 +351,19 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) { + clock := faketime.NewManualClock() // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { proto.stack = s return proto }}, + Clock: clock, }) protoNum := proto.Number() @@ -373,7 +372,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 1 has the link address "a", and added the network address 1. - ep1 = &fwdTestLinkEndpoint{ + ep1 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "a", @@ -386,7 +385,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 2 has the link address "b", and added the network address 2. - ep2 = &fwdTestLinkEndpoint{ + ep2 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "b", @@ -416,7 +415,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) } - return ep1, ep2 + return clock, ep1, ep2 } func TestForwardingWithStaticResolver(t *testing.T) { @@ -432,7 +431,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -444,6 +443,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: default: @@ -475,7 +475,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -487,9 +487,10 @@ func TestForwardingWithFakeResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -508,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // Whether or not we use the neighbor cache here does not matter since // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -518,10 +519,11 @@ func TestForwardingWithNoResolver(t *testing.T) { Data: buf.ToVectorisedView(), })) + clock.Advance(proto.addrResolveDelay) select { case <-ep2.C: t.Fatal("Packet should not be forwarded") - case <-time.After(time.Second): + default: } } @@ -533,7 +535,7 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) const numPackets int = 5 // These packets will all be enqueued in the packet queue to wait for link @@ -547,12 +549,12 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { } // All packets should fail resolution. - // TODO(gvisor.dev/issue/5141): Use a fake clock. for i := 0; i < numPackets; i++ { + clock.Advance(proto.addrResolveDelay) select { case got := <-ep2.C: t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) - case <-time.After(100 * time.Millisecond): + default: } } } @@ -576,7 +578,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { } }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. @@ -596,9 +598,10 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -631,7 +634,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { @@ -645,9 +648,10 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -681,7 +685,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. @@ -697,9 +701,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -745,7 +750,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. @@ -761,9 +766,10 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { for i := 0; i < maxPendingResolutions; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 3670d5995..0a26f6dd8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() *IPTables { +func DefaultTables(seed uint32) *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ NATID: { @@ -182,7 +182,7 @@ func DefaultTables() *IPTables { Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ - seed: generateRandUint32(), + seed: seed, }, reaperDone: make(chan struct{}, 1), } @@ -371,6 +371,7 @@ func (it *IPTables) startReaper(interval time.Duration) { select { case <-it.reaperDone: return + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. case <-time.After(interval): bucket, interval = it.connections.reapUnused(bucket, interval) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index ac2fa777e..133bacdd0 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -16,14 +16,14 @@ package stack_test import ( "bytes" - "context" "encoding/binary" "fmt" + "math/rand" "testing" "time" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -497,8 +497,9 @@ func TestDADResolve(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ - Clock: clock, - SecureRNG: &secureRNG, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: stack.DADConfigurations{ @@ -605,7 +606,10 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, _ := e.ReadContext(context.Background()) + p, ok := e.Read() + if !ok { + t.Fatal("packet didn't arrive") + } // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -731,11 +735,13 @@ func TestDADFail(t *testing.T) { dadConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -761,16 +767,17 @@ func TestDADFail(t *testing.T) { // Wait for DAD to fail and make sure the address did // not get resolved. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) @@ -839,11 +846,13 @@ func TestDADStop(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -861,15 +870,16 @@ func TestDADStop(t *testing.T) { test.stopFn(t, s) // Wait for DAD to fail (since the address was removed during DAD). + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") } if !test.skipFinalAddrCheck { @@ -920,10 +930,12 @@ func TestSetNDPConfigurations(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, })}, + Clock: clock, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -1002,28 +1014,23 @@ func TestSetNDPConfigurations(t *testing.T) { t.Fatal(err) } - // Sleep until right (500ms before) before resolution to - // make sure the address didn't resolve on NIC(1) yet. - const delta = 500 * time.Millisecond - time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + // Sleep until right before resolution to make sure the address didn't + // resolve on NIC(1) yet. + const delta = 1 + clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { t.Fatal(err) @@ -1035,7 +1042,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) @@ -1048,13 +1055,13 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo if managedAddress { // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) // of the RA payload. - raPayload[1] |= (1 << 7) + raPayload[1] |= 1 << 7 } // Populate the Other Configurations flag field. if otherConfigurations { // The Other Configurations flag field is the 6th bit of byte #1 // (0-indexing) of the RA payload. - raPayload[1] |= (1 << 6) + raPayload[1] |= 1 << 6 } opts := ra.Options() opts.Serialize(optSer) @@ -1340,6 +1347,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { routerC: make(chan ndpRouterEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1348,6 +1356,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1369,10 +1378,11 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { // Wait for the invalidation time plus some buffer to make sure we do // not actually receive any invalidation events as we should not have // remembered the router in the first place. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.routerC: t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } } @@ -1383,6 +1393,7 @@ func TestRouterDiscovery(t *testing.T) { rememberRouter: true, } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1391,6 +1402,7 @@ func TestRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) expectRouterEvent := func(addr tcpip.Address, discovered bool) { @@ -1409,12 +1421,13 @@ func TestRouterDiscovery(t *testing.T) { expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.routerC: if diff := checkRouterEvent(e, addr, false); diff != "" { t.Errorf("router event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for router discovery event") } } @@ -1461,7 +1474,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) @@ -1478,7 +1491,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) }) } @@ -1547,6 +1560,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1555,6 +1569,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1576,10 +1591,11 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { // Wait for the invalidation time plus some buffer to make sure we do // not actually receive any invalidation events as we should not have // remembered the prefix in the first place. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.prefixC: t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } } @@ -1594,6 +1610,7 @@ func TestPrefixDiscovery(t *testing.T) { rememberPrefix: true, } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1602,6 +1619,7 @@ func TestPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1662,12 +1680,13 @@ func TestPrefixDiscovery(t *testing.T) { // Wait for prefix2's most recent invalidation job plus some buffer to // expire. + clock.Advance(time.Duration(lifetime) * time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for prefix discovery event") } @@ -1700,6 +1719,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { rememberPrefix: true, } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1708,6 +1728,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1731,21 +1752,23 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // with infinite valid lifetime which should not get invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) expectPrefixEvent(subnet, true) + clock.Advance(testInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with finite lifetime. // The prefix should get invalidated after 1s. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + clock.Advance(testInfiniteLifetime) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(testInfiniteLifetime): + default: t.Fatal("timed out waiting for prefix discovery event") } @@ -1756,19 +1779,21 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + clock.Advance(testInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with a prefix with a lifetime value greater than the // set infinite lifetime value. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) + clock.Advance((testInfiniteLifetimeSeconds + 1) * time.Second) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with 0 lifetime. @@ -3613,6 +3638,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3621,6 +3647,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3654,10 +3681,11 @@ func TestAutoGenAddrRemoval(t *testing.T) { // Wait for the original valid lifetime to make sure the original job got // cancelled/cleaned up. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } } @@ -3779,6 +3807,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3787,6 +3816,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3816,30 +3846,36 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // Should not get an invalidation event after the PI's invalidation // time. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } } +func makeSecretKey(t *testing.T) []byte { + secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes) + n, err := cryptorand.Read(secretKey) + if err != nil { + t.Fatalf("cryptorand.Read(_): %s", err) + } + if l := len(secretKey); n != l { + t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l) + } + return secretKey +} + // TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use // opaque interface identifiers when configured to do so. func TestAutoGenAddrWithOpaqueIID(t *testing.T) { const nicID = 1 const nicName = "nic1" - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + + secretKey := makeSecretKey(t) prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) @@ -3861,6 +3897,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3875,6 +3912,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3913,12 +3951,13 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(validLifetimeSecondPrefix1 * time.Second) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3944,15 +3983,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }() ipv6.MaxDesyncFactor = time.Nanosecond - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4231,13 +4262,12 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrType := addrType t.Run(addrType.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4248,6 +4278,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { RetransmitTimer: retransmitTimer, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4292,7 +4323,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4309,15 +4340,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { const maxRetries = 1 const lifetimeSeconds = 5 - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4326,6 +4349,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -4345,6 +4369,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4375,7 +4400,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict after some time has passed. - time.Sleep(failureTimer) + clock.Advance(failureTimer) rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { @@ -4390,12 +4415,13 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Let the next address resolve. addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) expectAutoGenAddrEvent(addr, newAddr) + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } @@ -4409,6 +4435,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // // We expect either just the invalidation event or the deprecation event // followed by the invalidation event. + clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer) select { case e := <-ndpDisp.autoGenAddrC: if e.eventType == deprecatedAddr { @@ -4421,7 +4448,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") } } else { @@ -4429,7 +4456,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for auto gen addr event") } } @@ -4863,6 +4890,7 @@ func TestCleanupNDPState(t *testing.T) { rememberPrefix: true, autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, @@ -4874,6 +4902,7 @@ func TestCleanupNDPState(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) expectRouterEvent := func() (bool, ndpRouterEvent) { @@ -5106,7 +5135,7 @@ func TestCleanupNDPState(t *testing.T) { // Should not get any more events (invalidation timers should have been // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.routerC: t.Error("unexpected router event") @@ -5236,6 +5265,23 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { expectNoDHCPv6Event() } +var _ rand.Source = (*savingRandSource)(nil) + +type savingRandSource struct { + s rand.Source + + lastInt63 int64 +} + +func (d *savingRandSource) Int63() int64 { + i := d.s.Int63() + d.lastInt63 = i + return i +} +func (d *savingRandSource) Seed(seed int64) { + d.s.Seed(seed) +} + // TestRouterSolicitation tests the initial Router Solicitations that are sent // when a NIC newly becomes enabled. func TestRouterSolicitation(t *testing.T) { @@ -5402,6 +5448,9 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("unexpectedly got a packet = %#v", p) } } + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), + } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5411,8 +5460,10 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, })}, - Clock: clock, + Clock: clock, + RandSource: &randSource, }) + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -5425,19 +5476,27 @@ func TestRouterSolicitation(t *testing.T) { // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) + if remaining != 0 { + maxRtrSolicitDelay := test.maxRtrSolicitDelay + if maxRtrSolicitDelay < 0 { + maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay + } + var actualRtrSolicitDelay time.Duration + if maxRtrSolicitDelay != 0 { + actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay + } + waitForPkt(actualRtrSolicitDelay) remaining-- } subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + for ; remaining != 0; remaining-- { + if test.effectiveRtrSolicitInt != 0 { waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) waitForPkt(time.Nanosecond) } else { - waitForPkt(test.effectiveRtrSolicitInt) + waitForPkt(0) } } @@ -5533,12 +5592,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { + waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) + clock.Advance(timeout) + p, ok := e.Read() if !ok { t.Fatal("timed out waiting for packet") } @@ -5552,6 +5610,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS()) } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5561,6 +5620,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { MaxRtrSolicitationDelay: delay, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5568,13 +5628,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { // A single RS may have been sent before solicitations were stopped. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok = e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok = e.Read(); ok { t.Fatal("should not have sent more than one RS message") } } @@ -5582,9 +5640,8 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") } @@ -5595,21 +5652,19 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Start soliciting routers. test.startFn(t, s) - waitForPkt(delay + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + waitForPkt(clock, delay) + waitForPkt(clock, interval) + waitForPkt(clock, interval) + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") } // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") } }) diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 509f5ce5c..08857e1a9 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -310,7 +310,7 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { func (n *neighborCache) init(nic *nic, r LinkAddressResolver) { *n = neighborCache{ nic: nic, - state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator), + state: NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.randomGenerator), linkRes: r, } n.mu.Lock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 9821a18d3..7de25fe37 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -301,10 +293,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }) } @@ -313,10 +305,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }) @@ -347,10 +339,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -419,10 +411,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -500,12 +492,12 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now().Add(-durationReachableNanos), } wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) } @@ -571,10 +563,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -616,10 +608,10 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -663,10 +655,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -689,10 +681,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -733,10 +725,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -758,10 +750,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -814,20 +806,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -844,10 +836,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -875,10 +867,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -890,10 +882,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -910,10 +902,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -947,10 +939,10 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -973,20 +965,20 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -1027,10 +1019,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -1062,13 +1054,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - startedAt := clock.NowNanoseconds() + startedAt := clock.Now() // The following logic is very similar to overflowCache, but // periodically refreshes the frequently used entry. // Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1118,7 +1110,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { State: Reachable, // Can be inferred since the frequently used entry is the first to // be created and transitioned to Reachable. - UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(), + UpdatedAt: startedAt.Add(typicalLatency), }, } @@ -1127,12 +1119,12 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1190,12 +1182,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { if !ok { t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1244,10 +1236,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1263,10 +1255,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1301,10 +1293,10 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1405,10 +1397,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1436,10 +1428,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -1455,10 +1447,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { { wantEntries := []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, } if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { @@ -1488,10 +1480,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1518,10 +1510,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1541,10 +1533,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) @@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) { linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..0a59eecdd 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -31,10 +31,10 @@ const ( // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAtNanos int64 + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -138,10 +138,10 @@ func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState * // calling `setStateLocked`. func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: cache.nic.stack.clock.Now(), } n := &neighborEntry{ cache: cache, @@ -166,14 +166,20 @@ func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) { if ch := e.mu.done; ch != nil { close(ch) e.mu.done = nil - // Dequeue the pending packets in a new goroutine to not hold up the current + // Dequeue the pending packets asynchronously to not hold up the current // goroutine as writing packets may be a costly operation. // // At the time of writing, when writing packets, a neighbor's link address // is resolved (which ends up obtaining the entry's lock) while holding the - // link resolution queue's lock. Dequeuing packets in a new goroutine avoids - // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + // link resolution queue's lock. Dequeuing packets asynchronously avoids a + // lock ordering violation. + // + // NB: this is equivalent to spawning a goroutine directly using the go + // keyword but allows tests that use manual clocks to deterministically + // wait for this work to complete. + e.cache.nic.stack.clock.AfterFunc(0, func() { + e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + }) } } @@ -224,7 +230,7 @@ func (e *neighborEntry) cancelTimerLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() e.dispatchRemoveEventLocked() e.cancelTimerLocked() // TODO(https://gvisor.dev/issues/5583): test the case where this function is @@ -246,7 +252,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.mu.neigh.State e.mu.neigh.State = next - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() config := e.nudState.Config() switch next { @@ -307,7 +313,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -354,14 +360,14 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Unknown, Unreachable: prev := e.mu.neigh.State e.mu.neigh.State = Incomplete - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() switch prev { case Unknown: e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +384,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..59d86d6d4 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. // It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ @@ -354,10 +349,10 @@ func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes * EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -415,10 +410,10 @@ func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entry EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -446,7 +441,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { // UpdatedAt should remain the same during address resolution. e.mu.Lock() - startedAt := e.mu.neigh.UpdatedAtNanos + startedAt := e.mu.neigh.UpdatedAt e.mu.Unlock() // Wait for the rest of the reachability probe transmissions, signifying @@ -470,7 +465,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.mu.neigh.UpdatedAtNanos, startedAt; got != want { + if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want { t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -485,10 +480,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -547,10 +542,10 @@ func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -644,10 +639,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -678,10 +673,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -757,10 +752,10 @@ func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *tes EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -943,10 +938,10 @@ func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDis EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -998,10 +993,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1050,10 +1045,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1102,10 +1097,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1191,10 +1186,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1243,10 +1238,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1284,10 +1279,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1332,10 +1327,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1391,10 +1386,10 @@ func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTe EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + UpdatedAt: clock.Now(), }, }, } @@ -1443,10 +1438,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1498,10 +1493,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1553,10 +1548,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1645,10 +1640,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1697,10 +1692,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1770,10 +1765,10 @@ func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatc EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + UpdatedAt: clock.Now(), }, }, } @@ -1827,10 +1822,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1882,10 +1877,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -2003,10 +1998,10 @@ func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, lin EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: linkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: linkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -2155,10 +2150,10 @@ func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDD EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -2227,10 +2222,10 @@ func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkR EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dbba2c79f..378389db2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. +func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,7 +773,7 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } @@ -807,20 +794,20 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +839,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..ca9822bca 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -16,6 +16,7 @@ package stack import ( "math" + "math/rand" "sync" "time" @@ -313,45 +314,36 @@ func calcMaxRandomFactor(minRandomFactor float32) float32 { return defaultMaxRandomFactor } -// A Rand is a source of random numbers. -type Rand interface { - // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). - Float32() float32 -} - // NUDState stores states needed for calculating reachable time. type NUDState struct { - rng Rand + clock tcpip.Clock + rng *rand.Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex + mu struct { + sync.RWMutex - config NUDConfigurations + config NUDConfigurations - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified // random number generator for use in recomputing ReachableTime. -func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { +func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState { s := &NUDState{ - rng: rng, + clock: clock, + rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +351,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +369,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if s.clock.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +400,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. - s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = s.clock.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..1aeb2f8a5 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -16,6 +16,7 @@ package stack_test import ( "math" + "math/rand" "testing" "time" @@ -28,17 +29,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) @@ -48,12 +47,14 @@ type fakeRand struct { num float32 } -var _ stack.Rand = (*fakeRand)(nil) +var _ rand.Source = (*fakeRand)(nil) -func (f *fakeRand) Float32() float32 { - return f.num +func (f *fakeRand) Int63() int64 { + return int64(f.num * float32(1<<63)) } +func (*fakeRand) Seed(int64) {} + func TestNUDFunctions(t *testing.T) { const nicID = 1 @@ -169,7 +170,7 @@ func TestNUDFunctions(t *testing.T) { t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) } else if test.expectedErr == nil { if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Errorf("neighbors mismatch (-want +got):\n%s", diff) @@ -710,7 +711,8 @@ func TestNUDStateReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) if got, want := s.ReachableTime(), test.want; got != want { t.Errorf("got ReachableTime = %q, want = %q", got, want) } @@ -782,7 +784,8 @@ func TestNUDStateRecomputeReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) old := s.ReachableTime() if got, want := s.ReachableTime(), old; got != want { diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go index 421fb5c15..c8294eb6e 100644 --- a/pkg/tcpip/stack/rand.go +++ b/pkg/tcpip/stack/rand.go @@ -15,7 +15,7 @@ package stack import ( - mathrand "math/rand" + "math/rand" "gvisor.dev/gvisor/pkg/sync" ) @@ -23,7 +23,7 @@ import ( // lockedRandomSource provides a threadsafe rand.Source. type lockedRandomSource struct { mu sync.Mutex - src mathrand.Source + src rand.Source } func (r *lockedRandomSource) Int63() (n int64) { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8814f45a6..40d277312 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,17 +20,16 @@ package stack import ( - "bytes" "encoding/binary" "fmt" "io" - mathrand "math/rand" + "math/rand" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/atomicbitops" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -40,13 +39,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. - resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -145,7 +137,7 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. - randomGenerator *mathrand.Rand + randomGenerator *rand.Rand // secureRNG is a cryptographically secure random number generator. secureRNG io.Reader @@ -196,9 +188,9 @@ type Options struct { // TransportProtocols lists the transport protocols to enable. TransportProtocols []TransportProtocolFactory - // Clock is an optional clock source used for timestampping packets. + // Clock is an optional clock used for timekeeping. // - // If no Clock is specified, the clock source will be time.Now. + // If Clock is nil, tcpip.NewStdClock() will be used. Clock tcpip.Clock // Stats are optional statistic counters. @@ -225,15 +217,21 @@ type Options struct { // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data - // returned by rand.Read(). + // returned by the stack secure RNG. // // RandSource must be thread-safe. - RandSource mathrand.Source + RandSource rand.Source - // IPTables are the initial iptables rules. If nil, iptables will allow + // IPTables are the initial iptables rules. If nil, DefaultIPTables will be + // used to construct the initial iptables rules. // all traffic. IPTables *IPTables + // DefaultIPTables is an optional iptables rules constructor that is called + // if IPTables is nil. If both fields are nil, iptables will allow all + // traffic. + DefaultIPTables func(uint32) *IPTables + // SecureRNG is a cryptographically secure random number generator. SecureRNG io.Reader } @@ -331,23 +329,32 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + if opts.SecureRNG == nil { + opts.SecureRNG = cryptorand.Reader + } + randSrc := opts.RandSource if randSrc == nil { - // Source provided by mathrand.NewSource is not thread-safe so + var v int64 + if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil { + panic(err) + } + // Source provided by rand.NewSource is not thread-safe so // we wrap it in a simple thread-safe version. - randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + randSrc = &lockedRandomSource{src: rand.NewSource(v)} } + randomGenerator := rand.New(randSrc) + seed := randomGenerator.Uint32() if opts.IPTables == nil { - opts.IPTables = DefaultTables() + if opts.DefaultIPTables == nil { + opts.DefaultIPTables = DefaultTables + } + opts.IPTables = opts.DefaultIPTables(seed) } opts.NUDConfigs.resetInvalidFields() - if opts.SecureRNG == nil { - opts.SecureRNG = rand.Reader - } - s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -360,11 +367,11 @@ func New(opts Options) *Stack { handleLocal: opts.HandleLocal, tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), + seed: seed, nudConfigs: opts.NUDConfigs, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), + randomGenerator: randomGenerator, secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, @@ -804,7 +811,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -856,7 +863,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), @@ -1819,7 +1826,7 @@ func (s *Stack) Seed() uint32 { // Rand returns a reference to a pseudo random generator that can be used // to generate random numbers as required. -func (s *Stack) Rand() *mathrand.Rand { +func (s *Stack) Rand() *rand.Rand { return s.randomGenerator } @@ -1829,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader { return s.secureRNG } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - -func generateRandInt64() int64 { - b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - panic(err) - } - buf := bytes.NewReader(b) - var v int64 - if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { - panic(err) - } - return v -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index 33824afd0..dfec4258a 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,78 +14,6 @@ package stack -import "time" - // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack - -// saveT is invoked by stateify. -func (t *TCPCubicState) saveT() unixTime { - return unixTime{t.T.Unix(), t.T.UnixNano()} -} - -// loadT is invoked by stateify. -func (t *TCPCubicState) loadT(unix unixTime) { - t.T = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (t *TCPRACKState) saveXmitTime() unixTime { - return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (t *TCPRACKState) loadXmitTime(unix unixTime) { - t.XmitTime = time.Unix(unix.second, unix.nano) -} - -// saveLastSendTime is invoked by stateify. -func (t *TCPSenderState) saveLastSendTime() unixTime { - return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (t *TCPSenderState) loadLastSendTime(unix unixTime) { - t.LastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) saveRTTMeasureTime() unixTime { - return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { - t.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.MeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { - return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { - r.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveSegTime is invoked by stateify. -func (t *TCPEndpointState) saveSegTime() unixTime { - return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} -} - -// loadSegTime is invoked by stateify. -func (t *TCPEndpointState) loadSegTime(unix unixTime) { - t.SegTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 02d54d29b..21951d05a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -166,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -197,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -463,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -920,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -941,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1066,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1118,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1140,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1220,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1248,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1287,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1324,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1574,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1615,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1641,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. @@ -1660,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1709,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2049,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2113,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2234,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2248,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } + + nicStats := s.NICInfo()[nicid].Stats - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + // Inbound packet. + rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. + txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2316,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -2603,15 +2640,17 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } dadConfigs := stack.DefaultDADConfigurations() + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, } e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2629,17 +2668,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { linkLocalAddr := header.LinkLocalAddr(linkAddr1) // Wait for DAD to resolve. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: // We should get a resolution event after 1s (default time to // resolve as per default NDP configurations). Waiting for that // resolution time + an extra 1s without a resolution event // means something is wrong. t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { t.Fatal(err) @@ -3270,8 +3310,9 @@ func TestDoDADWhenNICEnabled(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -3280,6 +3321,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3324,13 +3366,14 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(dadTransmits * retransmitTimer) select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) @@ -3837,8 +3880,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3875,8 +3916,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4223,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { @@ -4394,7 +4433,7 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) } else if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index ddff6e2d6..e90c1a770 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -39,7 +39,7 @@ type TCPCubicState struct { WMax float64 // T is the time when the current congestion avoidance was entered. - T time.Time `state:".(unixTime)"` + T tcpip.MonotonicTime // TimeSinceLastCongestion denotes the time since the current // congestion avoidance was entered. @@ -78,7 +78,7 @@ type TCPCubicState struct { type TCPRACKState struct { // XmitTime is the transmission timestamp of the most recent // acknowledged segment. - XmitTime time.Time `state:".(unixTime)"` + XmitTime tcpip.MonotonicTime // EndSequence is the ending TCP sequence number of the most recent // acknowledged segment. @@ -216,7 +216,7 @@ type TCPRTTState struct { // +stateify savable type TCPSenderState struct { // LastSendTime is the timestamp at which we sent the last segment. - LastSendTime time.Time `state:".(unixTime)"` + LastSendTime tcpip.MonotonicTime // DupAckCount is the number of Duplicate ACKs received. It is used for // fast retransmit. @@ -256,7 +256,7 @@ type TCPSenderState struct { RTTMeasureSeqNum seqnum.Value // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Closed indicates that the caller has closed the endpoint for // sending. @@ -313,7 +313,7 @@ type TCPSACKInfo struct { type RcvBufAutoTuneParams struct { // MeasureTime is the time at which the current measurement was // started. - MeasureTime time.Time `state:".(unixTime)"` + MeasureTime tcpip.MonotonicTime // CopiedBytes is the number of bytes copied to user space since this // measure began. @@ -341,7 +341,7 @@ type RcvBufAutoTuneParams struct { // RTTMeasureTime is the absolute time at which the current RTT // measurement period began. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Disabled is true if an explicit receive buffer is set for the // endpoint. @@ -429,7 +429,7 @@ type TCPEndpointState struct { ID TCPEndpointID // SegTime denotes the absolute time when this segment was received. - SegTime time.Time `state:".(unixTime)"` + SegTime tcpip.MonotonicTime // RcvBufState contains information about the state of the endpoint's // receive socket buffer. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 80ad1a9d4..8a8454a6a 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -16,8 +16,6 @@ package stack import ( "fmt" - "math/rand" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" @@ -223,7 +221,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -475,7 +473,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + seed: d.stack.Seed(), } eps.endpoints[id] = epsByNIC } @@ -502,7 +500,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum return nil } - return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) + return epsByNIC.checkEndpoint(flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 4848495c9..bdd245d5e 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -291,14 +291,17 @@ func TestBindToDeviceDistribution(t *testing.T) { wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) + t.Cleanup(func() { + wq.EventUnregister(&we) + close(ch) + }) var err tcpip.Error ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) } + t.Cleanup(ep.Close) eps[ep] = i go func(ep tcpip.Endpoint) { @@ -307,7 +310,6 @@ func TestBindToDeviceDistribution(t *testing.T) { } }(ep) - defer ep.Close() ep.SocketOptions().SetReusePort(endpoint.reuse) if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) @@ -332,7 +334,7 @@ func TestBindToDeviceDistribution(t *testing.T) { if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } - ports := make(map[uint16]tcpip.Endpoint) + endpoints := make(map[uint16]tcpip.Endpoint) stats := make(map[tcpip.Endpoint]int) for i := 0; i < npackets; i++ { // Send a packet. @@ -357,11 +359,11 @@ func TestBindToDeviceDistribution(t *testing.T) { } stats[ep]++ if i < nports { - ports[uint16(i)] = ep + endpoints[uint16(i)] = ep } else { // Check that all packets from one client are handled by the same // socket. - if want, got := ports[port], ep; want != got { + if want, got := endpoints[port], ep; want != got { t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) } } diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go index 7ce43a68e..371da2f40 100644 --- a/pkg/tcpip/stdclock.go +++ b/pkg/tcpip/stdclock.go @@ -60,11 +60,11 @@ type stdClock struct { // monotonicOffset is assigned maxMonotonic after restore so that the // monotonic time will continue from where it "left off" before saving as part // of S/R. - monotonicOffset int64 `state:"nosave"` + monotonicOffset MonotonicTime `state:"nosave"` // monotonicMU protects maxMonotonic. monotonicMU sync.Mutex `state:"nosave"` - maxMonotonic int64 + maxMonotonic MonotonicTime } // NewStdClock returns an instance of a clock that uses the time package. @@ -76,25 +76,25 @@ func NewStdClock() Clock { var _ Clock = (*stdClock)(nil) -// NowNanoseconds implements Clock.NowNanoseconds. -func (*stdClock) NowNanoseconds() int64 { - return time.Now().UnixNano() +// Now implements Clock.Now. +func (*stdClock) Now() time.Time { + return time.Now() } // NowMonotonic implements Clock.NowMonotonic. -func (s *stdClock) NowMonotonic() int64 { +func (s *stdClock) NowMonotonic() MonotonicTime { sinceBase := time.Since(s.baseTime) if sinceBase < 0 { panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime)) } - monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset + monotonicValue := s.monotonicOffset.Add(sinceBase) s.monotonicMU.Lock() defer s.monotonicMU.Unlock() // Monotonic time values must never decrease. - if monotonicValue > s.maxMonotonic { + if s.maxMonotonic.Before(monotonicValue) { s.maxMonotonic = monotonicValue } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 797778e08..91622fa4c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -64,17 +64,47 @@ func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } +// MonotonicTime is a monotonic clock reading. +// +// +stateify savable +type MonotonicTime struct { + nanoseconds int64 +} + +// Before reports whether the monotonic clock reading mt is before u. +func (mt MonotonicTime) Before(u MonotonicTime) bool { + return mt.nanoseconds < u.nanoseconds +} + +// After reports whether the monotonic clock reading mt is after u. +func (mt MonotonicTime) After(u MonotonicTime) bool { + return mt.nanoseconds > u.nanoseconds +} + +// Add returns the monotonic clock reading mt+d. +func (mt MonotonicTime) Add(d time.Duration) MonotonicTime { + return MonotonicTime{ + nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(), + } +} + +// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum) +// value that can be stored in a Duration, the maximum (or minimum) duration +// will be returned. To compute t-d for a duration d, use t.Add(-d). +func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration { + return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds)) +} + // A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. type Clock interface { - // NowNanoseconds returns the current real time as a number of - // nanoseconds since the Unix epoch. - NowNanoseconds() int64 + // Now returns the current local time. + Now() time.Time - // NowMonotonic returns a monotonic time value at nanosecond resolution. - NowMonotonic() int64 + // NowMonotonic returns the current monotonic clock reading. + NowMonotonic() MonotonicTime // AfterFunc waits for the duration to elapse and then calls f in its own // goroutine. It returns a Timer that can be used to cancel the call using @@ -861,6 +891,9 @@ type SettableSocketOption interface { isSettableSocketOption() } +// EndpointState represents the state of an endpoint. +type EndpointState uint8 + // CongestionControlState indicates the current congestion control state for // TCP sender. type CongestionControlState int @@ -897,6 +930,9 @@ type TCPInfoOption struct { // RTO is the retransmission timeout for the endpoint. RTO time.Duration + // State is the current endpoint protocol state. + State EndpointState + // CcState is the congestion control state. CcState CongestionControlState @@ -1552,6 +1588,10 @@ type IPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig *StatCounter + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable *StatCounter + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -1835,37 +1875,104 @@ type UDPStats struct { ChecksumErrors *StatCounter } +// NICNeighborStats holds metrics for the neighbor table. +type NICNeighborStats struct { + // LINT.IfChange(NICNeighborStats) + + // UnreachableEntryLookups counts the number of lookups performed on an + // entry in Unreachable state. + UnreachableEntryLookups *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICNeighborStats) +} + +// NICPacketStats holds basic packet statistics. +type NICPacketStats struct { + // LINT.IfChange(NICPacketStats) + + // Packets is the number of packets counted. + Packets *StatCounter + + // Bytes is the number of bytes counted. + Bytes *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICPacketStats) +} + +// NICStats holds NIC statistics. +type NICStats struct { + // LINT.IfChange(NICStats) + + // UnknownL3ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported network protocol. + UnknownL3ProtocolRcvdPackets *StatCounter + + // UnknownL4ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported transport protocol. + UnknownL4ProtocolRcvdPackets *StatCounter + + // MalformedL4RcvdPackets is the number of packets received by a NIC that + // could not be delivered to a transport endpoint because the L4 header could + // not be parsed. + MalformedL4RcvdPackets *StatCounter + + // Tx contains statistics about transmitted packets. + Tx NICPacketStats + + // Rx contains statistics about received packets. + Rx NICPacketStats + + // DisabledRx contains statistics about received packets on disabled NICs. + DisabledRx NICPacketStats + + // Neighbor contains statistics about neighbor entries. + Neighbor NICNeighborStats + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICStats) +} + +// FillIn returns a copy of s with nil fields initialized to new StatCounters. +func (s NICStats) FillIn() NICStats { + InitStatCounters(reflect.ValueOf(&s).Elem()) + return s +} + // Stats holds statistics about the networking stack. -// -// All fields are optional. type Stats struct { - // UnknownProtocolRcvdPackets is the number of packets received by the - // stack that were for an unknown or unsupported protocol. - UnknownProtocolRcvdPackets *StatCounter + // TODO(https://gvisor.dev/issues/5986): Make the DroppedPackets stat less + // ambiguous. - // MalformedRcvdPackets is the number of packets received by the stack - // that were deemed malformed. - MalformedRcvdPackets *StatCounter - - // DroppedPackets is the number of packets dropped due to full queues. + // DroppedPackets is the number of packets dropped at the transport layer. DroppedPackets *StatCounter - // ICMP breaks out ICMP-specific stats (both v4 and v6). + // NICs is an aggregation of every NIC's statistics. These should not be + // incremented using this field, but using the relevant NIC multicounters. + NICs NICStats + + // ICMP is an aggregation of every NetworkEndpoint's ICMP statistics (both v4 + // and v6). These should not be incremented using this field, but using the + // relevant NetworkEndpoint ICMP multicounters. ICMP ICMPStats - // IGMP breaks out IGMP-specific stats. + // IGMP is an aggregation of every NetworkEndpoint's IGMP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint IGMP multicounters. IGMP IGMPStats - // IP breaks out IP-specific stats (both v4 and v6). + // IP is an aggregation of every NetworkEndpoint's IP statistics. These should + // not be incremented using this field, but using the relevant NetworkEndpoint + // IP multicounters. IP IPStats - // ARP breaks out ARP-specific stats. + // ARP is an aggregation of every NetworkEndpoint's ARP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint ARP multicounters. ARP ARPStats - // TCP breaks out TCP-specific stats. + // TCP holds TCP-specific stats. TCP TCPStats - // UDP breaks out UDP-specific stats. + // UDP holds UDP-specific stats. UDP UDPStats } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 269081ff8..c96ae2f02 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "net" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -210,26 +209,6 @@ func TestAddressString(t *testing.T) { } } -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. - "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - func TestAddressWithPrefixSubnet(t *testing.T) { tests := []struct { addr Address diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index ab2dab60c..181ef799e 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -53,12 +53,14 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index c657714ba..27caa0c28 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -17,6 +17,8 @@ package link_resolution_test import ( "bytes" "fmt" + "net" + "runtime" "testing" "time" @@ -27,12 +29,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -283,9 +287,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, } host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -329,7 +335,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) { // Wait for an error due to link resolution failing, or the endpoint to be // writable. + if test.expectedWriteErr != nil { + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + } else { + clock.RunImmediatelyScheduledJobs() + } <-ch + { var r bytes.Reader r.Reset([]byte{0}) @@ -395,6 +411,242 @@ func TestTCPLinkResolutionFailure(t *testing.T) { } } +func TestForwardingWithLinkResolutionFailure(t *testing.T) { + const ( + incomingNICID = 1 + outgoingNICID = 2 + ttl = 2 + expectedHostUnreachableErrorCount = 1 + ) + outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != arp.ProtocolNumber { + t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) + } + if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(request.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) + } + } + + ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber) + } + + snmc := header.SolicitedNodeAddr(dst) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), + checker.SrcAddr(src), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(dst), + )) + } + + icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4HostUnreachable), + ), + ) + } + + icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv6.DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6AddressUnreachable), + ), + ) + } + + tests := []struct { + name string + networkProtocolFactory []stack.NetworkProtocolFactory + networkProtocolNumber tcpip.NetworkProtocolNumber + sourceAddr tcpip.Address + destAddr tcpip.Address + incomingAddr tcpip.AddressWithPrefix + outgoingAddr tcpip.AddressWithPrefix + transportProtocol func(*stack.Stack) stack.TransportProtocol + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) + icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) + mtu uint32 + }{ + { + name: "IPv4 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + networkProtocolNumber: header.IPv4ProtocolNumber, + sourceAddr: tcptestutil.MustParse4("10.0.0.2"), + destAddr: tcptestutil.MustParse4("11.0.0.2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + }, + transportProtocol: icmp.NewProtocol4, + linkResolutionRequestChecker: arpChecker, + icmpReplyChecker: icmpv4Checker, + rx: rxICMPv4EchoRequest, + mtu: ipv4.MaxTotalSize, + }, + { + name: "IPv6 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + networkProtocolNumber: header.IPv6ProtocolNumber, + sourceAddr: tcptestutil.MustParse6("10::2"), + destAddr: tcptestutil.MustParse6("11::2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + }, + transportProtocol: icmp.NewProtocol6, + linkResolutionRequestChecker: ndpChecker, + icmpReplyChecker: icmpv6Checker, + rx: rxICMPv6EchoRequest, + mtu: header.IPv6MinimumMTU, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + + s := stack.New(stack.Options{ + NetworkProtocols: test.networkProtocolFactory, + TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, + Clock: clock, + }) + + // Set up endpoint through which we will receive packets. + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) + } + incomingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.incomingAddr, + } + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + } + + // Set up endpoint through which we will attempt to forward packets. + outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) + outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) + } + outgoingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.outgoingAddr, + } + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.incomingAddr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: test.outgoingAddr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) + } + + test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) + + nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) + if err != nil { + t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) + } + // Trigger the first packet on the endpoint. + clock.RunImmediatelyScheduledJobs() + + for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { + request, ok := outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ARP packet through outgoing NIC") + } + + test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) + + // Advance the clock the span of one request timeout. + clock.Advance(nudConfigs.RetransmitTimer) + } + + // Next, we make a blocking read to retrieve the error packet. This is + // necessary because outgoing packets are dequeued asynchronously when + // link resolution fails, and this dequeue is what triggers the ICMP + // error. + reply, ok := incomingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP packet through incoming NIC") + } + + test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) + + // Since link resolution failed, we don't expect the packet to be + // forwarded. + forwardedPacket, ok := outgoingEndpoint.Read() + if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) + } + + if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { + t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) + } + }) + } +} + func TestGetLinkAddress(t *testing.T) { const ( host1NICID = 1 @@ -449,8 +701,10 @@ func TestGetLinkAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + Clock: clock, } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -466,8 +720,20 @@ func TestGetLinkAddress(t *testing.T) { if test.expectedErr == nil { wantRes.LinkAddress = utils.LinkAddr2 } - if diff := cmp.Diff(wantRes, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + select { + case got := <-ch: + if diff := cmp.Diff(wantRes, got); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("event didn't arrive") } }) } @@ -544,8 +810,10 @@ func TestRouteResolvedFields(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + Clock: clock, } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -575,8 +843,20 @@ func TestRouteResolvedFields(t *testing.T) { if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + + select { + case got := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + } + default: + t.Fatalf("event didn't arrive") } if test.expectedErr != nil { @@ -789,11 +1069,16 @@ func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo d.c <- e } -func (d *nudDispatcher) waitForEvent(want eventInfo) error { - if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) +func (d *nudDispatcher) expectEvent(want eventInfo) error { + select { + case got := <-d.c: + if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil + default: + return fmt.Errorf("event didn't arrive") } - return nil } // TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it @@ -804,7 +1089,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address neighborAddr tcpip.Address - getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) isHost1Listener bool }{ { @@ -812,23 +1097,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -836,23 +1123,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -860,23 +1149,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -884,23 +1175,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -908,23 +1201,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -933,23 +1228,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -958,23 +1255,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -983,23 +1282,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -1037,14 +1338,14 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) } } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryAdded, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, }); err != nil { t.Fatalf("error waiting for initial NUD event: %s", err) } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1064,7 +1365,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // See NUDConfigurations.BaseReachableTime for more information. maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1072,8 +1373,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for stale NUD event: %s", err) } - listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) - defer listenerEP.Close() + listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) defer clientEP.Close() listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} if err := listenerEP.Bind(listenerAddr); err != nil { @@ -1094,14 +1394,15 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // with confirmation that the neighbor is reachable (indicated by a // successful 3-way handshake). <-clientCH - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, }); err != nil { t.Fatalf("error waiting for delay NUD event: %s", err) } - if err := nudDisp.waitForEvent(eventInfo{ + <-listenerCH + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1109,26 +1410,55 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for reachable NUD event: %s", err) } + peerEP, peerWQ, err := listenerEP.Accept(nil) + if err != nil { + t.Fatalf("listenerEP.Accept(): %s", err) + } + defer peerEP.Close() + peerWE, peerCH := waiter.NewChannelEntry(nil) + peerWQ.EventRegister(&peerWE, waiter.ReadableEvents) + // Wait for the neighbor to be stale again then send data to the remote. // // On successful transmission, the neighbor should become reachable // without probing the neighbor as a TCP ACK would be received which is an // indication of the neighbor being reachable. clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, }); err != nil { t.Fatalf("error waiting for stale NUD event: %s", err) } - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := clientEP.Write(&r, wOpts); err != nil { - t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + } + // Heads up, there is a race here. + // + // Incoming TCP segments are handled in + // tcp.(*endpoint).handleSegmentLocked: + // + // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the + // segment queue and notifies waiting readers (such as this channel) + // + // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment + // and notifies the NUD machinery that the peer is reachable + // + // Thus we must permit a delay between the readable signal and the + // expected NUD event. + // + // At the time of writing, this race is reliably hit with gotsan. + <-peerCH + for len(nudDisp.c) == 0 { + runtime.Gosched() } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1141,7 +1471,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // TCP should not mark the route reachable and NUD should go through the // probe state. clock.Advance(nudConfigs.DelayFirstProbeTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1149,7 +1479,16 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for probe NUD event: %s", err) } } - if err := nudDisp.waitForEvent(eventInfo{ + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := peerEP.Write(&r, wOpts); err != nil { + t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err) + } + } + <-clientCH + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go index f84d399fb..94b580a70 100644 --- a/pkg/tcpip/testutil/testutil.go +++ b/pkg/tcpip/testutil/testutil.go @@ -109,3 +109,15 @@ func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) er return nil } + +// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on +// error. +// +// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. +func MustParseLink(addr string) tcpip.LinkAddress { + parsed, err := tcpip.ParseMACAddress(addr) + if err != nil { + panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) + } + return parsed +} diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/tcpip/time.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index 1633d0aeb..ed1ed8ac6 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -15,6 +15,7 @@ package tcpip_test import ( + "math" "sync" "testing" "time" @@ -22,10 +23,85 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) +func TestMonotonicTimeBefore(t *testing.T) { + var mt tcpip.MonotonicTime + if mt.Before(mt) { + t.Errorf("%#v.Before(%#v)", mt, mt) + } + + one := mt.Add(1) + if one.Before(mt) { + t.Errorf("%#v.Before(%#v)", one, mt) + } + if !mt.Before(one) { + t.Errorf("!%#v.Before(%#v)", mt, one) + } +} + +func TestMonotonicTimeAfter(t *testing.T) { + var mt tcpip.MonotonicTime + if mt.After(mt) { + t.Errorf("%#v.After(%#v)", mt, mt) + } + + one := mt.Add(1) + if mt.After(one) { + t.Errorf("%#v.After(%#v)", mt, one) + } + if !one.After(mt) { + t.Errorf("!%#v.After(%#v)", one, mt) + } +} + +func TestMonotonicTimeAddSub(t *testing.T) { + var mt tcpip.MonotonicTime + if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { + t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) + } + + min := mt.Add(math.MinInt64) + max := mt.Add(math.MaxInt64) + + if overflow := mt.Add(1).Add(math.MaxInt64); overflow != max { + t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) + } + if underflow := mt.Add(-1).Add(math.MinInt64); underflow != min { + t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow) + } + + if got, want := min.Sub(min), time.Duration(0); want != got { + t.Errorf("got min.Sub(min) = %d, want %d", got, want) + } + if got, want := max.Sub(max), time.Duration(0); want != got { + t.Errorf("got max.Sub(max) = %d, want %d", got, want) + } + + if overflow, want := max.Sub(min), time.Duration(math.MaxInt64); overflow != want { + t.Errorf("mt.Add(math.MaxInt64).Sub(mt.Add(math.MinInt64) != %s (%#v)", want, overflow) + } + if underflow, want := min.Sub(max), time.Duration(math.MinInt64); underflow != want { + t.Errorf("mt.Add(math.MinInt64).Sub(mt.Add(math.MaxInt64) != %s (%#v)", want, underflow) + } +} + +func TestMonotonicTimeSub(t *testing.T) { + var mt tcpip.MonotonicTime + + if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { + t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) + } + + if max, overflow := mt.Add(math.MaxInt64), mt.Add(1).Add(math.MaxInt64); max != overflow { + t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) + } + if max, underflow := mt.Add(math.MinInt64), mt.Add(-1).Add(math.MinInt64); max != underflow { + t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", max, underflow) + } +} + const ( shortDuration = 1 * time.Nanosecond middleDuration = 100 * time.Millisecond - longDuration = 1 * time.Second ) func TestJobReschedule(t *testing.T) { @@ -53,10 +129,14 @@ func TestJobReschedule(t *testing.T) { wg.Wait() } +func stdClockWithAfter() (tcpip.Clock, func(time.Duration) <-chan time.Time) { + return tcpip.NewStdClock(), time.After +} + func TestJobExecution(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -68,7 +148,7 @@ func TestJobExecution(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -76,14 +156,14 @@ func TestJobExecution(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -99,7 +179,7 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -107,14 +187,14 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -128,7 +208,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): + case <-after(middleDuration): } job.Schedule(shortDuration) @@ -136,7 +216,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -144,14 +224,14 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -167,14 +247,19 @@ func TestJobImmediatelyCancel(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): + case <-after(middleDuration): } } +func stdClockWithAfterAndSleep() (tcpip.Clock, func(time.Duration) <-chan time.Time, func(time.Duration)) { + clock, after := stdClockWithAfter() + return clock, after, time.Sleep +} + func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after, sleep := stdClockWithAfterAndSleep() var lock sync.Mutex ch := make(chan struct{}) @@ -189,7 +274,7 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration * 2) + sleep(middleDuration * 2) job.Cancel() lock.Unlock() } @@ -199,14 +284,14 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration * 2): + case <-after(middleDuration * 2): } } func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after, sleep := stdClockWithAfterAndSleep() var lock sync.Mutex ch := make(chan struct{}) @@ -215,7 +300,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration) + sleep(middleDuration) job.Cancel() job.Schedule(shortDuration) } @@ -224,7 +309,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { // Wait for double the duration for the last timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -232,14 +317,14 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -255,7 +340,7 @@ func TestManyJobReschedulesUnderLock(t *testing.T) { // Wait for double the duration for the last timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -263,6 +348,6 @@ func TestManyJobReschedulesUnderLock(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8afde7fca..39f526023 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -16,6 +16,7 @@ package icmp import ( "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -33,7 +34,7 @@ type icmpPacket struct { icmpPacketEntry senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + receivedAt time.Time `state:".(int64)"` } type endpointState int @@ -160,7 +161,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { @@ -193,7 +194,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: p.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.timestamp, + Timestamp: p.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -349,7 +350,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { +func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { return nil } @@ -390,7 +391,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -606,7 +607,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. @@ -615,7 +616,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ } // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) switch err.(type) { @@ -800,7 +801,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 28a56a2d5..b8b839e4a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,23 @@ package icmp import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *icmpPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *icmpPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves icmpPacket.data field. func (p *icmpPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 496eca581..cd8c99d41 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -27,6 +27,7 @@ package packet import ( "fmt" "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,9 +42,8 @@ type packet struct { packetEntry // data holds the actual packet data, including any headers and // payload. - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // timestampNS is the unix time at which the packet was received. - timestampNS int64 + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress // packetInfo holds additional information like the protocol @@ -159,7 +159,7 @@ func (ep *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -189,7 +189,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul Total: packet.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: packet.timestampNS, + Timestamp: packet.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -220,19 +220,19 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes *tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { +func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error { return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -318,7 +318,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -339,7 +339,7 @@ func (ep *endpoint) UpdateLastError(err tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -451,7 +451,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } } - packet.timestampNS = ep.stack.Clock().NowNanoseconds() + packet.receivedAt = ep.stack.Clock().Now() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() @@ -484,7 +484,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { } // SetOwner implements tcpip.Endpoint.SetOwner. -func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*endpoint) SetOwner(tcpip.PacketOwner) {} // SocketOptions implements tcpip.Endpoint.SocketOptions. func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 5bd860d20..e729921db 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,11 +15,23 @@ package packet import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *packet) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *packet) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves packet.data field. func (p *packet) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index bcec3d2e7..1bce2769a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -27,6 +27,7 @@ package raw import ( "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,9 +42,8 @@ type rawPacket struct { rawPacketEntry // data holds the actual packet data, including any headers and // payload. - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // timestampNS is the unix time at which the packet was received. - timestampNS int64 + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress } @@ -183,7 +183,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.mu.Lock() @@ -219,7 +219,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: pkt.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: pkt.timestampNS, + Timestamp: pkt.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -402,7 +402,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Find a route to the destination. - route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) + route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) if err != nil { return err } @@ -428,7 +428,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -439,7 +439,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -513,12 +513,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -621,7 +621,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV - packet.timestampNS = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 5d6f2709c..39669b445 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,11 +15,23 @@ package raw import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *rawPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *rawPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves rawPacket.data field. func (p *rawPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 0f20d3856..8436d2cf0 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -41,7 +41,6 @@ go_library( "protocol.go", "rack.go", "rcv.go", - "rcv_state.go", "reno.go", "reno_recovery.go", "sack.go", @@ -53,7 +52,6 @@ go_library( "segment_state.go", "segment_unsafe.go", "snd.go", - "snd_state.go", "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", @@ -134,6 +132,7 @@ go_test( deps = [ "//pkg/sleep", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d4bd4e80e..2c65b737d 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -114,8 +114,8 @@ type listenContext struct { } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. -func timeStamp() uint32 { - return uint32(time.Now().Unix()>>6) & tsMask +func timeStamp(clock tcpip.Clock) uint32 { + return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Seconds()) >> 6 & tsMask } // newListenContext creates a new listen context. @@ -171,7 +171,7 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // createCookie creates a SYN cookie for the given id and incoming sequence // number. func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value { - ts := timeStamp() + ts := timeStamp(l.stack.Clock()) v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) v += (l.cookieHash(id, ts, 1) + data) & hashMask return seqnum.Value(v) @@ -181,7 +181,7 @@ func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Va // sequence number. If it is, it also returns the data originally encoded in the // cookie when createCookie was called. func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { - ts := timeStamp() + ts := timeStamp(l.stack.Clock()) v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) cookieTS := v >> tsOffset if ((ts - cookieTS) & tsMask) > maxTSDiff { @@ -247,7 +247,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - isn := generateSecureISN(s.id, l.stack.Seed()) + isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed()) ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err @@ -550,7 +550,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed e.rcvQueueInfo.rcvQueueMu.Unlock() - if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { + if rcvClosed || s.flags.Contains(header.TCPFlagSyn|header.TCPFlagAck) { // If the endpoint is shutdown, reply with reset. // // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST @@ -591,7 +591,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), + TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())), TSEcr: opts.TSVal, MSS: calculateAdvertisedMSS(e.userMSS, route), } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5e03e7715..570e5081c 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -19,7 +19,6 @@ import ( "math" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -92,7 +91,7 @@ type handshake struct { rcvWndScale int // startTime is the time at which the first SYN/SYN-ACK was sent. - startTime time.Time + startTime tcpip.MonotonicTime // deferAccept if non-zero will drop the final ACK for a passive // handshake till an ACK segment with data is received or the timeout is @@ -147,21 +146,16 @@ func FindWndScale(wnd seqnum.Size) int { // resetState resets the state of the handshake object such that it becomes // ready for a new 3-way handshake. func (h *handshake) resetState() { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - h.state = handshakeSynSent h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed()) } // generateSecureISN generates a secure Initial Sequence number based on the // recommendation here https://tools.ietf.org/html/rfc6528#page-3. -func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { +func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value { isnHasher := jenkins.Sum32(seed) isnHasher.Write([]byte(id.LocalAddress)) isnHasher.Write([]byte(id.RemoteAddress)) @@ -180,7 +174,7 @@ func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { // // Which sort of guarantees that we won't reuse the ISN for a new // connection for the same tuple for at least 274s. - isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + isn := isnHasher.Sum32() + uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Nanoseconds()>>6) return seqnum.Value(isn) } @@ -212,7 +206,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea // a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in // response. func (h *handshake) checkAck(s *segment) bool { - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber != h.iss+1 { // RFC 793, page 36, states that a reset must be generated when // the connection is in any non-synchronized state and an // incoming segment acknowledges something not yet sent. The @@ -230,8 +224,8 @@ func (h *handshake) checkAck(s *segment) bool { func (h *handshake) synSentState(s *segment) tcpip.Error { // RFC 793, page 37, states that in the SYN-SENT state, a reset is // acceptable if the ack field acknowledges the SYN. - if s.flagIsSet(header.TCPFlagRst) { - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + if s.flags.Contains(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == h.iss+1 { // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK // was acceptable then signal the user "error: connection reset", drop // the segment, enter CLOSED state, delete TCB, and return." @@ -249,7 +243,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // We are in the SYN-SENT state. We only care about segments that have // the SYN flag. - if !s.flagIsSet(header.TCPFlagSyn) { + if !s.flags.Contains(header.TCPFlagSyn) { return nil } @@ -270,7 +264,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // If this is a SYN ACK response, we only need to acknowledge the SYN // and the handshake is completed. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { h.state = handshakeCompleted h.ep.transitionToStateEstablishedLocked(h) @@ -316,7 +310,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // synRcvdState handles a segment received when the TCP 3-way handshake is in // the SYN-RCVD state. func (h *handshake) synRcvdState(s *segment) tcpip.Error { - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { // RFC 793, page 37, states that in the SYN-RCVD state, a reset // is acceptable if the sequence number is in the window. if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { @@ -340,13 +334,13 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { return nil } - if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { + if s.flags.Contains(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { // We received two SYN segments with different sequence // numbers, so we reset this and restart the whole // process, except that we don't reset the timer. ack := s.sequenceNumber.Add(s.logicalLen()) seq := seqnum.Value(0) - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { seq = s.ackNumber } h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) @@ -378,10 +372,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { // If deferAccept is not zero and this is a bare ACK and the // timeout is not hit then drop the ACK. - if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + if h.deferAccept != 0 && s.data.Size() == 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) < h.deferAccept { h.acked = true h.ep.stack.Stats().DroppedPackets.Increment() return nil @@ -426,7 +420,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { func (h *handshake) handleSegment(s *segment) tcpip.Error { h.sndWnd = s.window - if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { + if !s.flags.Contains(header.TCPFlagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) } @@ -474,7 +468,7 @@ func (h *handshake) processSegments() tcpip.Error { // start sends the first SYN/SYN-ACK. It does not block, even if link address // resolution is required. func (h *handshake) start() { - h.startTime = time.Now() + h.startTime = h.ep.stack.Clock().NowMonotonic() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { @@ -527,7 +521,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -552,7 +546,7 @@ func (h *handshake) complete() tcpip.Error { // The last is required to provide a way for the peer to complete // the connection with another ACK or data (as ACKs are never // retransmitted on their own). - if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + if h.active || !h.acked || h.deferAccept != 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) > h.deferAccept { h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, @@ -608,15 +602,15 @@ func (h *handshake) complete() tcpip.Error { type backoffTimer struct { timeout time.Duration maxTimeout time.Duration - t *time.Timer + t tcpip.Timer } -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { +func newBackoffTimer(clock tcpip.Clock, timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { if timeout > maxTimeout { return nil, &tcpip.ErrTimeout{} } bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} - bt.t = time.AfterFunc(timeout, f) + bt.t = clock.AfterFunc(timeout, f) return bt, nil } @@ -634,7 +628,7 @@ func (bt *backoffTimer) stop() { } func parseSynSegmentOptions(s *segment) header.TCPSynOptions { - synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) + synOpts := header.ParseSynOptions(s.options, s.flags.Contains(header.TCPFlagAck)) if synOpts.TS { s.parsedOptions.TSVal = synOpts.TSVal s.parsedOptions.TSEcr = synOpts.TSEcr @@ -1188,11 +1182,11 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // the TCPEndpointState after the segment is processed. defer e.probeSegmentLocked() - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { return false, err } - } else if s.flagIsSet(header.TCPFlagSyn) { + } else if s.flags.Contains(header.TCPFlagSyn) { // See: https://tools.ietf.org/html/rfc5961#section-4.1 // 1) If the SYN bit is set, irrespective of the sequence number, TCP // MUST send an ACK (also referred to as challenge ACK) to the remote @@ -1216,7 +1210,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // should then rely on SYN retransmission from the remote end to // re-establish the connection. e.snd.maybeSendOutOfWindowAck(s) - } else if s.flagIsSet(header.TCPFlagAck) { + } else if s.flags.Contains(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. s.window <<= e.snd.SndWndScale @@ -1235,7 +1229,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - state := e.state + state := e.EndpointState() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. @@ -1267,7 +1261,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error { // If a userTimeout is set then abort the connection if it is // exceeded. - if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + if userTimeout != 0 && e.stack.Clock().NowMonotonic().Sub(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() return &tcpip.ErrTimeout{} @@ -1322,7 +1316,7 @@ func (e *endpoint) disableKeepaliveTimer() { // segments. func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { e.mu.Lock() - var closeTimer *time.Timer + var closeTimer tcpip.Timer var closeWaker sleep.Waker epilogue := func() { @@ -1484,7 +1478,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) } } @@ -1721,7 +1715,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { var timeWaitWaker sleep.Waker s.AddWaker(&timeWaitWaker, timeWaitDone) - timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert) defer timeWaitTimer.Stop() for { diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 962f1d687..6985194bb 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -41,7 +41,7 @@ type cubicState struct { func newCubicCC(s *sender) *cubicState { return &cubicState{ TCPCubicState: stack.TCPCubicState{ - T: time.Now(), + T: s.ep.stack.Clock().NowMonotonic(), Beta: 0.7, C: 0.4, }, @@ -60,7 +60,7 @@ func (c *cubicState) enterCongestionAvoidance() { // https://tools.ietf.org/html/rfc8312#section-4.8 if c.numCongestionEvents == 0 { c.K = 0 - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) } @@ -115,14 +115,15 @@ func (c *cubicState) cubicCwnd(t float64) float64 { // getCwnd returns the current congestion window as computed by CUBIC. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { - elapsed := time.Since(c.T).Seconds() + elapsed := c.s.ep.stack.Clock().NowMonotonic().Sub(c.T) + elapsedSeconds := elapsed.Seconds() // Compute the window as per Cubic after 'elapsed' time // since last congestion event. - c.WC = c.cubicCwnd(elapsed - c.K) + c.WC = c.cubicCwnd(elapsedSeconds - c.K) // Compute the TCP friendly estimate of the congestion window. - c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds()) + c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsedSeconds/srtt.Seconds()) // Make sure in the TCP friendly region CUBIC performs at least // as well as Reno. @@ -134,7 +135,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int // In Concave/Convex region of CUBIC, calculate what CUBIC window // will be after 1 RTT and use that to grow congestion window // for every ack. - tEst := (time.Since(c.T) + srtt).Seconds() + tEst := (elapsed + srtt).Seconds() wtRtt := c.cubicCwnd(tEst - c.K) // As per 4.3 for each received ACK cwnd must be incremented // by (w_cubic(t+RTT) - cwnd/cwnd. @@ -151,7 +152,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) @@ -162,7 +163,7 @@ func (c *cubicState) HandleLossDetected() { // HandleRTOExpired implements congestionContrl.HandleRTOExpired. func (c *cubicState) HandleRTOExpired() { // See: https://tools.ietf.org/html/rfc8312#section-4.6 - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.numCongestionEvents = 0 c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) @@ -193,7 +194,7 @@ func (c *cubicState) fastConvergence() { // PostRecovery implemements congestionControl.PostRecovery. func (c *cubicState) PostRecovery() { - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() } // reduceSlowStartThreshold returns new SsThresh as described in diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 512053a04..dff7cb89c 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -16,10 +16,11 @@ package tcp import ( "encoding/binary" + "math/rand" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -141,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) { // in-order. type dispatcher struct { processors []processor - seed uint32 - wg sync.WaitGroup + // seed is a random secret for a jenkins hash. + seed uint32 + wg sync.WaitGroup } -func (d *dispatcher) init(nProcessors int) { +func (d *dispatcher) init(rng *rand.Rand, nProcessors int) { d.close() d.wait() d.processors = make([]processor, nProcessors) - d.seed = generateRandUint32() + d.seed = rng.Uint32() for i := range d.processors { p := &d.processors[i] p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) @@ -172,12 +174,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, clock, pkt) if !s.parse(pkt.RXTransportChecksumValidated) { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() @@ -185,7 +186,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans } if !s.csumValid { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.ChecksumErrors.Increment() ep.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() @@ -213,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans d.selectProcessor(id).queueEndpoint(ep) } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { var payload [4]byte binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f148d505d..5342aacfd 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -421,7 +421,7 @@ func testV4Accept(t *testing.T, c *context.Context) { r.Reset(data) nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) + tcp = header.IPv4(b).Payload() if string(tcp.Payload()) != data { t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 50d39cbad..a27e2110b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -20,12 +20,12 @@ import ( "fmt" "io" "math" + "math/rand" "runtime" "strings" "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,19 +38,15 @@ import ( ) // EndpointState represents the state of a TCP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - // Endpoint states internal to netstack. These map to the TCP state CLOSED. - StateInitial EndpointState = iota - StateBound - StateConnecting // Connect() called, but the initial SYN hasn't been sent. - StateError - - // TCP protocol states. + _ EndpointState = iota + // TCP protocol states in sync with the definitions in + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/tcp_states.h#L13 StateEstablished StateSynSent StateSynRecv @@ -62,6 +58,12 @@ const ( StateLastAck StateListen StateClosing + + // Endpoint states internal to netstack. + StateInitial + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError ) const ( @@ -97,6 +99,16 @@ func (s EndpointState) connecting() bool { } } +// internal returns true when the state is netstack internal. +func (s EndpointState) internal() bool { + switch s { + case StateInitial, StateBound, StateConnecting, StateError: + return true + default: + return false + } +} + // handshake returns true when s is one of the states representing an endpoint // in the middle of a TCP handshake. func (s EndpointState) handshake() bool { @@ -422,12 +434,12 @@ type endpoint struct { // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState `state:".(EndpointState)"` + state uint32 `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the // endpoint state at restore time as the socket is moved to it's correct // state. - origEndpointState EndpointState `state:"nosave"` + origEndpointState uint32 `state:"nosave"` isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` @@ -468,7 +480,7 @@ type endpoint struct { // recentTSTime is the unix time when we last updated // TCPEndpointStateInner.RecentTS. - recentTSTime time.Time `state:".(unixTime)"` + recentTSTime tcpip.MonotonicTime // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -626,7 +638,7 @@ type endpoint struct { // lastOutOfWindowAckTime is the time at which the an ACK was sent in response // to an out of window segment being received by this endpoint. - lastOutOfWindowAckTime time.Time `state:".(unixTime)"` + lastOutOfWindowAckTime tcpip.MonotonicTime } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -747,7 +759,7 @@ func (e *endpoint) ResumeWork() { // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + oldstate := EndpointState(atomic.LoadUint32(&e.state)) switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() @@ -764,18 +776,18 @@ func (e *endpoint) setEndpointState(state EndpointState) { e.stack.Stats().TCP.CurrentEstablished.Decrement() } } - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { e.RecentTS = recentTS - e.recentTSTime = time.Now() + e.recentTSTime = e.stack.Clock().NowMonotonic() } // recentTimestamp returns the value of the recentTS field. @@ -806,11 +818,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue }, sndQueueInfo: sndQueueInfo{ TCPSndBufState: stack.TCPSndBufState{ - SndMTU: int(math.MaxInt32), + SndMTU: math.MaxInt32, }, }, waiterQueue: waiterQueue, - state: StateInitial, + state: uint32(StateInitial), keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -870,9 +882,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.TSOffset = timeStampOffset() + e.TSOffset = timeStampOffset(e.stack.Rand()) e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) + e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker) return e } @@ -1189,7 +1201,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvQueueInfo.rcvQueueMu.Unlock() return } - now := time.Now() + now := e.stack.Clock().NowMonotonic() if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt { e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied e.rcvQueueInfo.rcvQueueMu.Unlock() @@ -1544,7 +1556,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // Add data to the send queue. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, v) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) e.sndQueueInfo.SndBufUsed += len(v) e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) e.sndQueueInfo.sndQueue.PushBack(s) @@ -1956,6 +1968,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { info := tcpip.TCPInfoOption{} e.LockUser() + if state := e.EndpointState(); state.internal() { + info.State = tcpip.EndpointState(StateClose) + } else { + info.State = tcpip.EndpointState(state) + } snd := e.snd if snd != nil { // We do not calculate RTT before sending the data packets. If @@ -2198,7 +2215,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } @@ -2224,7 +2241,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but // less than 1 second has elapsed since its recentTS was updated then // we cannot reuse the port. - if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + if tcpEP.EndpointState() != StateTimeWait || e.stack.Clock().NowMonotonic().Sub(tcpEP.recentTSTime) < 1*time.Second { tcpEP.UnlockUser() return false, nil } @@ -2245,7 +2262,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { return false, nil } } @@ -2370,7 +2387,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { } // Queue fin segment. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil) e.sndQueueInfo.sndQueue.PushBack(s) e.sndQueueInfo.SndBufInQueue++ // Mark endpoint as closed. @@ -2581,7 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) { id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2731,7 +2748,7 @@ func (e *endpoint) updateSndBufferUsage(v int) { // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1 e.sndQueueInfo.sndQueueMu.Unlock() if notify { @@ -2848,23 +2865,20 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(time.Now(), e.TSOffset) + return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { - return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 { + d := curTime.Sub(tcpip.MonotonicTime{}) + return uint32(d.Milliseconds()) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending // timestamp values in a timestamp option for a TCP segment. -func timeStampOffset() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } +func timeStampOffset(rng *rand.Rand) uint32 { // Initialize a random tsOffset that will be added to the recentTS // everytime the timestamp is sent when the Timestamp option is enabled. // @@ -2874,7 +2888,7 @@ func timeStampOffset() uint32 { // NOTE: This is not completely to spec as normally this should be // initialized in a manner analogous to how sequence numbers are // randomized per connection basis. But for now this is sufficient. - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return rng.Uint32() } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint @@ -2909,7 +2923,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { s := stack.TCPEndpointState{ TCPEndpointStateInner: e.TCPEndpointStateInner, ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID), - SegTime: time.Now(), + SegTime: e.stack.Clock().NowMonotonic(), Receiver: e.rcv.TCPReceiverState, Sender: e.snd.TCPSenderState, } @@ -2937,7 +2951,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { if cubic, ok := e.snd.cc.(*cubicState); ok { s.Sender.Cubic = cubic.TCPCubicState - s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T) + s.Sender.Cubic.TimeSinceLastCongestion = e.stack.Clock().NowMonotonic().Sub(s.Sender.Cubic.T) } s.Sender.RACKState = e.snd.rc.TCPRACKState @@ -3029,14 +3043,16 @@ func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { // allowOutOfWindowAck returns true if an out-of-window ACK can be sent now. func (e *endpoint) allowOutOfWindowAck() bool { - var limit stack.TCPInvalidRateLimitOption - if err := e.stack.Option(&limit); err != nil { - panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) - } + now := e.stack.Clock().NowMonotonic() - now := time.Now() - if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { - return false + if e.lastOutOfWindowAckTime != (tcpip.MonotonicTime{}) { + var limit stack.TCPInvalidRateLimitOption + if err := e.stack.Option(&limit); err != nil { + panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) + } + if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { + return false + } } e.lastOutOfWindowAckTime = now diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 6e9777fe4..952ccacdd 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -154,20 +154,25 @@ func (e *endpoint) afterLoad() { e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. - e.state = StateInitial + e.state = uint32(StateInitial) // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.keepalive.timer.init(s.Clock(), &e.keepalive.waker) + if snd := e.snd; snd != nil { + snd.resendTimer.init(s.Clock(), &snd.resendWaker) + snd.reorderTimer.init(s.Clock(), &snd.reorderWaker) + snd.probeTimer.init(s.Clock(), &snd.probeWaker) + } e.stack = s e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() - epState := e.origEndpointState + epState := EndpointState(e.origEndpointState) switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption @@ -281,32 +286,12 @@ func (e *endpoint) Resume(s *stack.Stack) { }() case epState == StateClose: e.isPortReserved = false - e.state = StateClose + e.state = uint32(StateClose) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) case epState == StateError: - e.state = StateError + e.state = uint32(StateError) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } - -// saveRecentTSTime is invoked by stateify. -func (e *endpoint) saveRecentTSTime() unixTime { - return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} -} - -// loadRecentTSTime is invoked by stateify. -func (e *endpoint) loadRecentTSTime(unix unixTime) { - e.recentTSTime = time.Unix(unix.second, unix.nano) -} - -// saveLastOutOfWindowAckTime is invoked by stateify. -func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { - return unixTime{e.lastOutOfWindowAckTime.Unix(), e.lastOutOfWindowAckTime.UnixNano()} -} - -// loadLastOutOfWindowAckTime is invoked by stateify. -func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { - e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2f9fe7ee0..65c86823a 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -65,7 +65,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, f.stack.Clock(), pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index a3d1aa1a3..2fc282e73 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -131,7 +131,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(ep, id, pkt) + p.dispatcher.queuePacket(ep, id, p.stack.Clock(), pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -142,14 +142,14 @@ func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEnd // particular, SYNs addressed to a non-existent connection are rejected by this // means." func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, p.stack.Clock(), pkt) defer s.decRef() if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } - if !s.flagIsSet(header.TCPFlagRst) { + if !s.flags.Contains(header.TCPFlagRst) { replyWithReset(p.stack, s, stack.DefaultTOS, 0) } @@ -181,7 +181,7 @@ func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { // reset has sequence number zero and the ACK field is set to the sum // of the sequence number and segment length of the incoming segment. // The connection remains in the CLOSED state. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { seq = s.ackNumber } else { flags |= header.TCPFlagAck @@ -401,7 +401,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er case *tcpip.TCPTimeWaitReuseOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) + *v = p.timeWaitReuse p.mu.RUnlock() return nil @@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection. recovery: 0, } - p.dispatcher.init(runtime.GOMAXPROCS(0)) + p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 9e332dcf7..0da4eafaa 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -79,7 +79,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // update will update the RACK related fields when an ACK has been received. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { - rtt := time.Now().Sub(seg.xmitTime) + rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime) tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a @@ -115,7 +115,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // ending sequence number of the packet which has been acknowledged // most recently. endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) { rc.XmitTime = seg.xmitTime rc.EndSequence = endSeq } @@ -174,7 +174,7 @@ func (s *sender) schedulePTO() { } s.rtt.Unlock() - now := time.Now() + now := s.ep.stack.Clock().NowMonotonic() if s.resendTimer.enabled() { if now.Add(pto).After(s.resendTimer.target) { pto = s.resendTimer.target.Sub(now) @@ -279,7 +279,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // been observed RACK uses reo_wnd of zero during loss recovery, in order to // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. -func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { +func (rc *rackControl) updateRACKReorderWindow() { dsackSeen := rc.DSACKSeen snd := rc.snd @@ -352,7 +352,7 @@ func (rc *rackControl) exitRecovery() { // detectLoss marks the segment as lost if the reordering window has elapsed // and the ACK is not received. It will also arm the reorder timer. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5. -func (rc *rackControl) detectLoss(rcvTime time.Time) int { +func (rc *rackControl) detectLoss(rcvTime tcpip.MonotonicTime) int { var timeout time.Duration numLost := 0 for seg := rc.snd.writeList.Front(); seg != nil && seg.xmitCount != 0; seg = seg.Next() { @@ -366,7 +366,7 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int { } endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) { timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd if timeRemaining <= 0 { seg.lost = true @@ -392,7 +392,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error { return nil } - numLost := rc.detectLoss(time.Now()) + numLost := rc.detectLoss(rc.snd.ep.stack.Clock().NowMonotonic()) if numLost == 0 { return nil } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 133371455..661ca604a 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -17,7 +17,6 @@ package tcp import ( "container/heap" "math" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,7 +49,7 @@ type receiver struct { pendingRcvdSegments segmentHeap // Time when the last ack was received. - lastRcvdAckTime time.Time `state:".(unixTime)"` + lastRcvdAckTime tcpip.MonotonicTime } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { @@ -63,7 +62,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - lastRcvdAckTime: time.Now(), + lastRcvdAckTime: ep.stack.Clock().NowMonotonic(), } } @@ -137,9 +136,9 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + if r.RcvNxt.Add(curWnd).LessThan(r.RcvNxt.Add(newWnd)) && toGrow { // If the new window moves the right edge, then update RcvAcc. - r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) + r.RcvAcc = r.RcvNxt.Add(newWnd) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -245,7 +244,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. - if s.flagIsSet(header.TCPFlagFin) { + if s.flags.Contains(header.TCPFlagFin) { r.RcvNxt++ // Send ACK immediately. @@ -261,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { @@ -296,7 +295,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -325,9 +324,9 @@ func (r *receiver) updateRTT() { // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. r.ep.rcvQueueInfo.rcvQueueMu.Lock() - if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime == (tcpip.MonotonicTime{}) { // New measurement. - r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic() r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return @@ -336,14 +335,14 @@ func (r *receiver) updateRTT() { r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) + rtt := r.ep.stack.Clock().NowMonotonic().Sub(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic() r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } @@ -424,7 +423,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { + if closed && (!s.flags.Contains(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -467,11 +466,11 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { } // Store the time of the last ack. - r.lastRcvdAckTime = time.Now() + r.lastRcvdAckTime = r.ep.stack.Clock().NowMonotonic() // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { - if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { + if segLen > 0 || s.flags.Contains(header.TCPFlagFin) { // We only store the segment if it's within our buffer size limit. // // Only use 75% of the receive buffer queue for out-of-order @@ -539,7 +538,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // // As we do not yet support PAWS, we are being conservative in ignoring // RSTs by default. - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { return false, false } @@ -559,13 +558,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { + if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } // Drop the segment if it does not contain an ACK. - if !s.flagIsSet(header.TCPFlagAck) { + if !s.flags.Contains(header.TCPFlagAck) { return false, false } @@ -574,7 +573,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flags.Contains(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go deleted file mode 100644 index 2bf21a2e7..000000000 --- a/pkg/tcpip/transport/tcp/rcv_state.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// saveLastRcvdAckTime is invoked by stateify. -func (r *receiver) saveLastRcvdAckTime() unixTime { - return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} -} - -// loadLastRcvdAckTime is invoked by stateify. -func (r *receiver) loadLastRcvdAckTime(unix unixTime) { - r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7e5ba6ef7..ca78c96f2 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -17,7 +17,6 @@ package tcp import ( "fmt" "sync/atomic" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -73,9 +72,9 @@ type segment struct { parsedOptions header.TCPOptions options []byte `state:".([]byte)"` hasNewSACKInfo bool - rcvdTime time.Time `state:".(unixTime)"` + rcvdTime tcpip.MonotonicTime // xmitTime is the last transmit time of this segment. - xmitTime time.Time `state:".(unixTime)"` + xmitTime tcpip.MonotonicTime xmitCount uint32 // acked indicates if the segment has already been SACKed. @@ -88,7 +87,7 @@ type segment struct { lost bool } -func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) *segment { netHdr := pkt.Network() s := &segment{ refCnt: 1, @@ -100,17 +99,17 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * } s.data = pkt.Data().ExtractVV().Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) - s.rcvdTime = time.Now() + s.rcvdTime = clock.NowMonotonic() s.dataMemSize = s.data.Size() return s } -func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, clock tcpip.Clock, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, } - s.rcvdTime = time.Now() + s.rcvdTime = clock.NowMonotonic() if len(v) != 0 { s.views[0] = v s.data = buffer.NewVectorisedView(len(v), s.views[:1]) @@ -149,16 +148,6 @@ func (s *segment) merge(oth *segment) { oth.dataMemSize = oth.data.Size() } -// flagIsSet checks if at least one flag in flags is set in s.flags. -func (s *segment) flagIsSet(flags header.TCPFlags) bool { - return s.flags&flags != 0 -} - -// flagsAreSet checks if all flags in flags are set in s.flags. -func (s *segment) flagsAreSet(flags header.TCPFlags) bool { - return s.flags&flags == flags -} - // setOwner sets the owning endpoint for this segment. Its required // to be called to ensure memory accounting for receive/send buffer // queues is done properly. @@ -198,10 +187,10 @@ func (s *segment) incRef() { // as the data length plus one for each of the SYN and FIN bits set. func (s *segment) logicalLen() seqnum.Size { l := seqnum.Size(s.data.Size()) - if s.flagIsSet(header.TCPFlagSyn) { + if s.flags.Contains(header.TCPFlagSyn) { l++ } - if s.flagIsSet(header.TCPFlagFin) { + if s.flags.Contains(header.TCPFlagFin) { l++ } return l @@ -243,7 +232,7 @@ func (s *segment) parse(skipChecksumValidation bool) bool { return false } - s.options = []byte(s.hdr[header.TCPMinimumSize:]) + s.options = s.hdr[header.TCPMinimumSize:] s.parsedOptions = header.ParseTCPOptions(s.options) if skipChecksumValidation { s.csumValid = true @@ -262,5 +251,5 @@ func (s *segment) parse(skipChecksumValidation bool) bool { // sackBlock returns a header.SACKBlock that represents this segment. func (s *segment) sackBlock() header.SACKBlock { - return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} + return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())} } diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go index 7422d8c02..dcfa80f95 100644 --- a/pkg/tcpip/transport/tcp/segment_state.go +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -15,8 +15,6 @@ package tcp import ( - "time" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -55,23 +53,3 @@ func (s *segment) loadOptions(options []byte) { // allocated so there is no cost here. s.options = options } - -// saveRcvdTime is invoked by stateify. -func (s *segment) saveRcvdTime() unixTime { - return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} -} - -// loadRcvdTime is invoked by stateify. -func (s *segment) loadRcvdTime(unix unixTime) { - s.rcvdTime = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (s *segment) saveXmitTime() unixTime { - return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (s *segment) loadXmitTime(unix unixTime) { - s.rcvdTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 486016fc0..2e6ea06f5 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -39,10 +40,11 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW } func TestSegmentMerge(t *testing.T) { + var clock faketime.NullClock id := stack.TransportEndpointID{} - seg1 := newOutgoingSegment(id, buffer.NewView(10)) + seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10)) defer seg1.decRef() - seg2 := newOutgoingSegment(id, buffer.NewView(20)) + seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20)) defer seg2.decRef() checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index f43e86677..72d58dcff 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -94,7 +94,7 @@ type sender struct { // firstRetransmittedSegXmitTime is the original transmit time of // the first segment that was retransmitted due to RTO expiration. - firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + firstRetransmittedSegXmitTime tcpip.MonotonicTime // zeroWindowProbing is set if the sender is currently probing // for zero receive window. @@ -169,7 +169,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint SndUna: iss + 1, SndNxt: iss + 1, RTTMeasureSeqNum: iss + 1, - LastSendTime: time.Now(), + LastSendTime: ep.stack.Clock().NowMonotonic(), MaxPayloadSize: maxPayloadSize, MaxSentAck: irs + 1, FastRecovery: stack.TCPFastRecoveryState{ @@ -197,9 +197,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.SndWndScale = uint8(sndWndScale) } - s.resendTimer.init(&s.resendWaker) - s.reorderTimer.init(&s.reorderWaker) - s.probeTimer.init(&s.probeWaker) + s.resendTimer.init(s.ep.stack.Clock(), &s.resendWaker) + s.reorderTimer.init(s.ep.stack.Clock(), &s.reorderWaker) + s.probeTimer.init(s.ep.stack.Clock(), &s.probeWaker) s.updateMaxPayloadSize(int(ep.route.MTU()), 0) @@ -441,7 +441,7 @@ func (s *sender) retransmitTimerExpired() bool { // timeout since the first retransmission. uto := s.ep.userTimeout - if s.firstRetransmittedSegXmitTime.IsZero() { + if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) { // We store the original xmitTime of the segment that we are // about to retransmit as the retransmission time. This is // required as by the time the retransmitTimer has expired the @@ -450,7 +450,7 @@ func (s *sender) retransmitTimerExpired() bool { s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime } - elapsed := time.Since(s.firstRetransmittedSegXmitTime) + elapsed := s.ep.stack.Clock().NowMonotonic().Sub(s.firstRetransmittedSegXmitTime) remaining := s.maxRTO if uto != 0 { // Cap to the user specified timeout if one is specified. @@ -616,7 +616,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // 'S2' that meets the following 3 criteria for determinig // loss, the sequence range of one segment of up to SMSS // octects starting with S2 MUST be returned. - if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) { + if !s.ep.scoreboard.IsSACKED(header.SACKBlock{Start: segSeq, End: segSeq.Add(1)}) { // NextSeg(): // // (1.a) S2 is greater than HighRxt @@ -866,8 +866,8 @@ func (s *sender) enableZeroWindowProbing() { // We piggyback the probing on the retransmit timer with the // current retranmission interval, as we may start probing while // segment retransmissions. - if s.firstRetransmittedSegXmitTime.IsZero() { - s.firstRetransmittedSegXmitTime = time.Now() + if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) { + s.firstRetransmittedSegXmitTime = s.ep.stack.Clock().NowMonotonic() } s.resendTimer.enable(s.RTO) } @@ -875,7 +875,7 @@ func (s *sender) enableZeroWindowProbing() { func (s *sender) disableZeroWindowProbing() { s.zeroWindowProbing = false s.unackZeroWindowProbes = 0 - s.firstRetransmittedSegXmitTime = time.Time{} + s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{} s.resendTimer.disable() } @@ -925,7 +925,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO { + if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && s.ep.stack.Clock().NowMonotonic().Sub(s.LastSendTime) > s.RTO { if s.SndCwnd > InitialCwnd { s.SndCwnd = InitialCwnd } @@ -1024,7 +1024,7 @@ func (s *sender) SetPipe() { if segEnd.LessThan(endSeq) { endSeq = segEnd } - sb := header.SACKBlock{startSeq, endSeq} + sb := header.SACKBlock{Start: startSeq, End: endSeq} // SetPipe(): // // After initializing pipe to zero, the following steps are @@ -1132,7 +1132,7 @@ func (s *sender) isDupAck(seg *segment) bool { // (b) The incoming acknowledgment carries no data. seg.logicalLen() == 0 && // (c) The SYN and FIN bits are both off. - !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && + !seg.flags.Intersects(header.TCPFlagFin|header.TCPFlagSyn) && // (d) the ACK number is equal to the greatest acknowledgment received on // the given connection (TCP.UNA from RFC793). seg.ackNumber == s.SndUna && @@ -1234,7 +1234,7 @@ func checkDSACK(rcvdSeg *segment) bool { func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { - s.updateRTO(time.Now().Sub(s.RTTMeasureTime)) + s.updateRTO(s.ep.stack.Clock().NowMonotonic().Sub(s.RTTMeasureTime)) s.RTTMeasureSeqNum = s.SndNxt } @@ -1444,7 +1444,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if s.SndUna == s.SndNxt { s.Outstanding = 0 // Reset firstRetransmittedSegXmitTime to the zero value. - s.firstRetransmittedSegXmitTime = time.Time{} + s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{} s.resendTimer.disable() s.probeTimer.disable() } @@ -1455,7 +1455,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: // * Step 4: Update RACK reordering window - s.rc.updateRACKReorderWindow(rcvdSeg) + s.rc.updateRACKReorderWindow() // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the @@ -1502,7 +1502,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } } - seg.xmitTime = time.Now() + seg.xmitTime = s.ep.stack.Clock().NowMonotonic() seg.xmitCount++ seg.lost = false err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) @@ -1527,7 +1527,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { - s.LastSendTime = time.Now() + s.LastSendTime = s.ep.stack.Clock().NowMonotonic() if seq == s.RTTMeasureSeqNum { s.RTTMeasureTime = s.LastSendTime } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go deleted file mode 100644 index 2f805d8ce..000000000 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// +stateify savable -type unixTime struct { - second int64 - nano int64 -} - -// afterLoad is invoked by stateify. -func (s *sender) afterLoad() { - s.resendTimer.init(&s.resendWaker) - s.reorderTimer.init(&s.reorderWaker) - s.probeTimer.init(&s.probeWaker) -} - -// saveFirstRetransmittedSegXmitTime is invoked by stateify. -func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { - return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} -} - -// loadFirstRetransmittedSegXmitTime is invoked by stateify. -func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { - s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index c58361bc1..d6cf786a1 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -38,7 +38,7 @@ const ( func setStackRACKPermitted(t *testing.T, c *context.Context) { t.Helper() - opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection) + opt := tcpip.TCPRACKLossDetection if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) } @@ -50,7 +50,7 @@ func TestRACKUpdate(t *testing.T) { c := context.New(t, uint32(mtu)) defer c.Cleanup() - var xmitTime time.Time + var xmitTime tcpip.MonotonicTime probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { // Validate that the endpoint Sender.RACKState is what we expect. @@ -79,7 +79,7 @@ func TestRACKUpdate(t *testing.T) { } // Write the data. - xmitTime = time.Now() + xmitTime = c.Stack().Clock().NowMonotonic() var r bytes.Reader r.Reset(data) if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9916182e3..e7ede7662 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4259,7 +4259,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(iss), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -5335,7 +5335,7 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next-1)), + checker.TCPSeqNum(next-1), checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), @@ -5360,12 +5360,7 @@ func TestKeepalive(t *testing.T) { }) checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), + checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), ) if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { @@ -5507,7 +5502,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + uint16(lastPortOffset), + SrcPort: context.TestPort + lastPortOffset, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(context.TestInitialSequenceNumber), @@ -5884,7 +5879,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { r.Reset(data) newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) + tcp = header.IPv4(pkt).Payload() if string(tcp.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } @@ -6118,7 +6113,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { } pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + tcpHdr = header.IPv4(pkt).Payload() if string(tcpHdr.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) } @@ -6375,7 +6370,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Allocate a large enough payload for the test. payloadSize := receiveBufferSize * 2 - b := make([]byte, int(payloadSize)) + b := make([]byte, payloadSize) worker := (c.EP).(interface { StopWork() @@ -6429,7 +6424,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // ack, 1 for the non-zero window p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(uint32(wantAckNum)), + checker.TCPAckNum(wantAckNum), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { @@ -6484,14 +6479,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := uint32(rawEP.TSVal) + tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) + return uint16(rcvWnd >> c.WindowScale) } // Allocate a large array to send to the endpoint. b := make([]byte, receiveBufferSize*48) @@ -6619,19 +6614,16 @@ func TestDelayEnabled(t *testing.T) { defer c.Cleanup() checkDelayOption(t, c, false, false) // Delay is disabled by default. - for _, v := range []struct { - delayEnabled tcpip.TCPDelayEnabled - wantDelayOption bool - }{ - {delayEnabled: false, wantDelayOption: false}, - {delayEnabled: true, wantDelayOption: true}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + for _, delayEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + opt := tcpip.TCPDelayEnabled(delayEnabled) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) + } + checkDelayOption(t, c, opt, delayEnabled) + }) } } @@ -7042,7 +7034,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Receive the SYN-ACK reply. b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) + tcpHdr = header.IPv4(b).Payload() c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders = &context.Headers{ @@ -7467,7 +7459,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), + checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), @@ -7545,7 +7537,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: iss, - AckNum: seqnum.Value(c.IRS + 1), + AckNum: c.IRS + 1, RcvWnd: 30000, }) diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 38a335840..5645c772e 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -15,21 +15,29 @@ package tcp import ( + "math" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" ) type timerState int const ( + // The timer is disabled. timerStateDisabled timerState = iota + // The timer is enabled, but the clock timer may be set to an earlier + // expiration time due to a previous orphaned state. timerStateEnabled + // The timer is disabled, but the clock timer is enabled, which means that + // it will cause a spurious wakeup unless the timer is enabled before the + // clock timer fires. timerStateOrphaned ) // timer is a timer implementation that reduces the interactions with the -// runtime timer infrastructure by letting timers run (and potentially +// clock timer infrastructure by letting timers run (and potentially // eventually expire) even if they are stopped. It makes it cheaper to // disable/reenable timers at the expense of spurious wakes. This is useful for // cases when the same timer is disabled/reenabled repeatedly with relatively @@ -39,44 +47,37 @@ const ( // (currently at least 200ms), and get disabled when acks are received, and // reenabled when new pending segments are sent. // -// It is advantageous to avoid interacting with the runtime because it acquires +// It is advantageous to avoid interacting with the clock because it acquires // a global mutex and performs O(log n) operations, where n is the global number // of timers, whenever a timer is enabled or disabled, and may make a syscall. // // This struct is thread-compatible. type timer struct { - // state is the current state of the timer, it can be one of the - // following values: - // disabled - the timer is disabled. - // orphaned - the timer is disabled, but the runtime timer is - // enabled, which means that it will evetually cause a - // spurious wake (unless it gets enabled again before - // then). - // enabled - the timer is enabled, but the runtime timer may be set - // to an earlier expiration time due to a previous - // orphaned state. state timerState + clock tcpip.Clock + // target is the expiration time of the current timer. It is only // meaningful in the enabled state. - target time.Time + target tcpip.MonotonicTime - // runtimeTarget is the expiration time of the runtime timer. It is + // clockTarget is the expiration time of the clock timer. It is // meaningful in the enabled and orphaned states. - runtimeTarget time.Time + clockTarget tcpip.MonotonicTime - // timer is the runtime timer used to wait on. - timer *time.Timer + // timer is the clock timer used to wait on. + timer tcpip.Timer } // init initializes the timer. Once it expires, it the given waker will be // asserted. -func (t *timer) init(w *sleep.Waker) { +func (t *timer) init(clock tcpip.Clock, w *sleep.Waker) { t.state = timerStateDisabled + t.clock = clock - // Initialize a runtime timer that will assert the waker, then + // Initialize a clock timer that will assert the waker, then // immediately stop it. - t.timer = time.AfterFunc(time.Hour, func() { + t.timer = t.clock.AfterFunc(math.MaxInt64, func() { w.Assert() }) t.timer.Stop() @@ -106,9 +107,9 @@ func (t *timer) checkExpiration() bool { // The timer is enabled, but it may have expired early. Check if that's // the case, and if so, reset the runtime timer to the correct time. - now := time.Now() + now := t.clock.NowMonotonic() if now.Before(t.target) { - t.runtimeTarget = t.target + t.clockTarget = t.target t.timer.Reset(t.target.Sub(now)) return false } @@ -134,11 +135,11 @@ func (t *timer) enabled() bool { // enable enables the timer, programming the runtime timer if necessary. func (t *timer) enable(d time.Duration) { - t.target = time.Now().Add(d) + t.target = t.clock.NowMonotonic().Add(d) // Check if we need to set the runtime timer. - if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) { - t.runtimeTarget = t.target + if t.state == timerStateDisabled || t.target.Before(t.clockTarget) { + t.clockTarget = t.target t.timer.Reset(d) } diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go index dbd6dff54..479752de7 100644 --- a/pkg/tcpip/transport/tcp/timer_test.go +++ b/pkg/tcpip/transport/tcp/timer_test.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) func TestCleanup(t *testing.T) { @@ -27,9 +28,11 @@ func TestCleanup(t *testing.T) { isAssertedTimeoutSeconds = timerDurationSeconds + 1 ) + clock := faketime.NewManualClock() + tmr := timer{} w := sleep.Waker{} - tmr.init(&w) + tmr.init(clock, &w) tmr.enable(timerDurationSeconds * time.Second) tmr.cleanup() @@ -39,7 +42,7 @@ func TestCleanup(t *testing.T) { // The waker should not be asserted. for i := 0; i < isAssertedTimeoutSeconds; i++ { - time.Sleep(time.Second) + clock.Advance(time.Second) if w.IsAsserted() { t.Fatalf("waker asserted unexpectedly") } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index dd5c910ae..cdc344ab7 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -49,6 +49,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f7dd50d35..b964e446b 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -17,6 +17,7 @@ package udp import ( "io" "sync/atomic" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -34,19 +35,20 @@ type udpPacket struct { destinationAddress tcpip.FullAddress packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + receivedAt time.Time `state:".(int64)"` // tos stores either the receiveTOS or receiveTClass value. tos uint8 } // EndpointState represents the state of a UDP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - StateInitial EndpointState = iota + _ EndpointState = iota + StateInitial StateBound StateConnected StateClosed @@ -98,7 +100,7 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState + state uint32 route *stack.Route `state:"manual"` dstPort uint16 ttl uint8 @@ -176,7 +178,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), - state: StateInitial, + state: uint32(StateInitial), uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) @@ -204,12 +206,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState() returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -290,7 +292,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -320,7 +322,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.timestamp, + Timestamp: p.receivedAt.UnixNano(), } if e.ops.GetReceiveTOS() { cm.HasTOS = true @@ -801,8 +803,8 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ - e.multicastNICID, - e.multicastAddr, + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, } e.mu.Unlock() @@ -1075,7 +1077,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, nil /* testPort */) + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */) if err != nil { return id, bindToDevice, err } @@ -1301,12 +1303,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.LocalAddress, - Port: header.UDP(hdr).DestinationPort(), + Port: hdr.DestinationPort(), }, data: pkt.Data().ExtractVV(), } @@ -1328,7 +1330,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB packet.packetInfo.LocalAddr = localAddr packet.packetInfo.DestinationAddr = localAddr packet.packetInfo.NIC = pkt.NICID - packet.timestamp = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 4aba68b21..1f638c3f6 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,26 +15,38 @@ package udp import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *udpPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *udpPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves udpPacket.data field. -func (u *udpPacket) saveData() buffer.VectorisedView { - // We cannot save u.data directly as u.data.views may alias to u.views, +func (p *udpPacket) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, // which is not allowed by state framework (in-struct pointer). - return u.data.Clone(nil) + return p.data.Clone(nil) } // loadData loads udpPacket.data field. -func (u *udpPacket) loadData(data buffer.VectorisedView) { - // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization +func (p *udpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization // here because data.views is not guaranteed to be loaded by now. Plus, // data.views will be allocated anyway so there really is little point - // of utilizing u.views for data.views. - u.data = data + // of utilizing p.views for data.views. + p.data = data } // afterLoad is invoked by stateify. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 705ad1f64..7c357cb09 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = StateConnected + ep.state = uint32(StateConnected) ep.rcvMu.Lock() ep.rcvReady = true diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index dc2e3f493..4008cacf2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,17 +16,16 @@ package udp_test import ( "bytes" - "context" "fmt" "io/ioutil" "math/rand" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -298,16 +297,18 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: true, - }) + return newDualTestContextWithHandleLocal(t, mtu, true) } -func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { +func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext { t.Helper() + options := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: handleLocal, + Clock: &faketime.NullClock{}, + } s := stack.New(options) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -378,9 +379,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { c.t.Fatalf("Packet wasn't written out") return nil @@ -534,7 +533,9 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -606,7 +607,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe case <-ch: res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - case <-time.After(300 * time.Millisecond): + default: if packetShouldBeDropped { return // expected to time out } @@ -820,11 +821,7 @@ func TestV4ReadSelfSource(t *testing.T) { {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, } { t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: tt.handleLocal, - }) + c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal) defer c.cleanup() c.createEndpointForFlow(unicastV4) @@ -1034,17 +1031,17 @@ func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, che payload := testWriteNoVerify(c, flow, setDest) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP + var udpH header.UDP if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) + udpH = header.IPv4(b).Payload() } else { - udp = header.UDP(header.IPv6(b).Payload()) + udpH = header.IPv6(b).Payload() } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + if !bytes.Equal(payload, udpH.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload) } - return udp.SourcePort() + return udpH.SourcePort() } func testDualWrite(c *testContext) uint16 { @@ -1198,7 +1195,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { r.Reset(payload) n, err := c.ep.Write(&r, writeOpts) if err != nil { - c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) + c.t.Fatalf("c.ep.Write(...) = %s, want nil", err) } if got, want := n, int64(len(payload)); got != want { c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) @@ -1462,7 +1459,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { name: "IPv4 unicast", proto: header.IPv4ProtocolNumber, flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, }, { name: "IPv4 multicast", @@ -1474,7 +1471,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, }, { name: "IPv4 broadcast", @@ -1486,13 +1483,13 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, }, { name: "IPv6 unicast", proto: header.IPv6ProtocolNumber, flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, }, { name: "IPv6 multicast", @@ -1504,7 +1501,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, }, } @@ -1614,6 +1611,7 @@ func TestTTL(t *testing.T) { } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{p}, + Clock: &faketime.NullClock{}, }) ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) wantTTL = ep.DefaultTTL() @@ -1759,7 +1757,7 @@ func TestReceiveTosTClass(t *testing.T) { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) } - want := true + const want = true optionSetter(want) got := optionGetter() @@ -1889,18 +1887,14 @@ func TestV4UnknownDestination(t *testing.T) { } } if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { + if p, ok := c.linkEP.Read(); ok { t.Fatalf("unexpected packet received: %+v", p) } return } // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { t.Fatalf("packet wasn't written out") return @@ -1987,18 +1981,14 @@ func TestV6UnknownDestination(t *testing.T) { } } if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { + if p, ok := c.linkEP.Read(); ok { t.Fatalf("unexpected packet received: %+v", p) } return } // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { t.Fatalf("packet wasn't written out") return @@ -2115,8 +2105,8 @@ func TestShortHeader(t *testing.T) { Data: buf.ToVectorisedView(), })) - if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) } } @@ -2124,25 +2114,27 @@ func TestShortHeader(t *testing.T) { // global and endpoint stats are incremented. func TestBadChecksumErrors(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV6} { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } + c.createEndpoint(flow.sockProto()) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) + payload := newPayload() + c.injectPacket(flow, payload, true /* badChecksum */) - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } + }) } } @@ -2484,6 +2476,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { |