diff options
128 files changed, 2445 insertions, 1360 deletions
diff --git a/.travis.yml b/.travis.yml index 9d3141f38..1d955b05d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -30,8 +30,10 @@ services: - docker jobs: include: - - os: linux - arch: amd64 + # AMD64 builds are tested on kokoro, so don't run them in travis to save + # capacity for arm64 builds. + # - os: linux + # arch: amd64 - os: linux arch: arm64 script: @@ -166,11 +166,14 @@ do-tests: runsc simple-tests: unit-tests # Compatibility target. .PHONY: simple-tests +IMAGE_FILTER := HelloWorld\|Httpd\|Ruby\|Stdio +INTEGRATION_FILTER := Life\|Pause\|Connect\|JobControl\|Overlay\|Exec\|DirCreation/root + docker-tests: load-basic-images @$(call submake,install-test-runtime RUNTIME="vfs1") @$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)") @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2") - @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_filter=.*TestHelloWorld" TARGETS="$(INTEGRATION_TARGETS)") + @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_filter=$(IMAGE_FILTER)\|$(INTEGRATION_FILTER)" TARGETS="$(INTEGRATION_TARGETS)") .PHONY: docker-tests overlay-tests: load-basic-images @@ -336,10 +339,10 @@ RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND% dev: ## Installs a set of local runtimes. Requires sudo. @$(call submake,refresh ARGS="--net-raw") - @$(call submake,configure RUNTIME="$(RUNTIME)" ARGS="--net-raw") - @$(call submake,configure RUNTIME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets") - @$(call submake,configure RUNTIME="$(RUNTIME)-p" ARGS="--net-raw --profile") - @$(call submake,configure RUNTIME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="--net-raw") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-p" ARGS="--net-raw --profile") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2") @sudo systemctl restart docker .PHONY: dev @@ -350,8 +353,8 @@ refresh: ## Refreshes the runtime binary (for development only). Must have calle install-test-runtime: ## Installs the runtime for testing. Requires sudo. @$(call submake,refresh ARGS="--net-raw --TESTONLY-test-name-env=RUNSC_TEST_NAME --debug --strace --log-packets $(ARGS)") - @$(call submake,configure RUNTIME=runsc) - @$(call submake,configure) + @$(call submake,configure RUNTIME_NAME=runsc) + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)") @sudo systemctl restart docker @if [[ -f /etc/docker/daemon.json ]]; then \ sudo chmod 0755 /etc/docker && \ @@ -360,7 +363,7 @@ install-test-runtime: ## Installs the runtime for testing. Requires sudo. .PHONY: install-test-runtime configure: ## Configures a single runtime. Requires sudo. Typically called from dev or install-test-runtime. - @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS) + @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME_NAME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS) @echo -e "$(INFO) Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)" @echo -e "$(INFO) Logs are in: $(RUNTIME_LOG_DIR)" @sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)" diff --git a/images/benchmarks/ffmpeg/Dockerfile b/images/benchmarks/ffmpeg/Dockerfile new file mode 100644 index 000000000..7108df64f --- /dev/null +++ b/images/benchmarks/ffmpeg/Dockerfile @@ -0,0 +1,9 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* +WORKDIR /media +ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4 diff --git a/images/benchmarks/redis/Dockerfile b/images/benchmarks/redis/Dockerfile new file mode 100644 index 000000000..0f17249af --- /dev/null +++ b/images/benchmarks/redis/Dockerfile @@ -0,0 +1 @@ +FROM redis:5.0.4 diff --git a/images/default/Dockerfile b/images/default/Dockerfile index 397082b02..2b38e6c58 100644 --- a/images/default/Dockerfile +++ b/images/default/Dockerfile @@ -1,7 +1,7 @@ FROM fedora:31 # Install bazel. RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel -RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch +RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch diffutils RUN pip install pycparser RUN dnf install -y bazel3 # Install gcloud. diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 789369220..5fb419bcd 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -8,7 +8,6 @@ go_template_instance( out = "dirty_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "Dirty", @@ -25,14 +24,14 @@ go_template_instance( name = "frame_ref_set_impl", out = "frame_ref_set_impl.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "fsutil", prefix = "FrameRef", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "uint64", "Functions": "FrameRefSetFunctions", }, @@ -43,7 +42,6 @@ go_template_instance( out = "file_range_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "FileRange", @@ -86,7 +84,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/state", diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go index c6cd45087..2c9446c1d 100644 --- a/pkg/sentry/fs/fsutil/dirty_set.go +++ b/pkg/sentry/fs/fsutil/dirty_set.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) { // repeatedly until all bytes have been written. max is the true size of the // cached object; offsets beyond max will not be passed to writeAt, even if // they are marked dirty. -func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { var changedDirty bool defer func() { if changedDirty { @@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet // successful partial write, SyncDirtyAll will call it repeatedly until all // bytes have been written. max is the true size of the cached object; offsets // beyond max will not be passed to writeAt, even if they are marked dirty. -func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { dseg := dirty.FirstSegment() for dseg.Ok() { if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil { @@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max } // Preconditions: mr must be page-aligned. -func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() { wbr := cseg.Range().Intersect(mr) if max < wbr.Start { diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index 5643cdac9..bbafebf03 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -23,13 +23,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/usermem" ) // FileRangeSet maps offsets into a memmap.Mappable to offsets into a -// platform.File. It is used to implement Mappables that store data in +// memmap.File. It is used to implement Mappables that store data in // sparsely-allocated memory. // // type FileRangeSet <generated by go_generics> @@ -65,20 +64,20 @@ func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, spli } // FileRange returns the FileRange mapped by seg. -func (seg FileRangeIterator) FileRange() platform.FileRange { +func (seg FileRangeIterator) FileRange() memmap.FileRange { return seg.FileRangeOf(seg.Range()) } // FileRangeOf returns the FileRange mapped by mr. // // Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0. -func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange { +func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange { frstart := seg.Value() + (mr.Start - seg.Start()) - return platform.FileRange{frstart, frstart + mr.Length()} + return memmap.FileRange{frstart, frstart + mr.Length()} } // Fill attempts to ensure that all memmap.Mappable offsets in required are -// mapped to a platform.File offset, by allocating from mf with the given +// mapped to a memmap.File offset, by allocating from mf with the given // memory usage kind and invoking readAt to store data into memory. (If readAt // returns a successful partial read, Fill will call it repeatedly until all // bytes have been read.) EOF is handled consistently with the requirements of @@ -141,7 +140,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map } // Drop removes segments for memmap.Mappable offsets in mr, freeing the -// corresponding platform.FileRanges. +// corresponding memmap.FileRanges. // // Preconditions: mr must be page-aligned. func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { @@ -154,7 +153,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { } // DropAll removes all segments in mr, freeing the corresponding -// platform.FileRanges. +// memmap.FileRanges. func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) { for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { mf.DecRef(seg.FileRange()) diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go index dd6f5aba6..a808894df 100644 --- a/pkg/sentry/fs/fsutil/frame_ref_set.go +++ b/pkg/sentry/fs/fsutil/frame_ref_set.go @@ -17,7 +17,7 @@ package fsutil import ( "math" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" ) @@ -39,7 +39,7 @@ func (FrameRefSetFunctions) ClearValue(val *uint64) { } // Merge implements segment.Functions.Merge. -func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) { +func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) { if val1 != val2 { return 0, false } @@ -47,13 +47,13 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform. } // Split implements segment.Functions.Split. -func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) { +func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) { return val, val } // IncRefAndAccount adds a reference on the range fr. All newly inserted segments // are accounted as host page cache memory mappings. -func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) { seg, gap := refs.Find(fr.Start) for { switch { @@ -74,7 +74,7 @@ func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { // DecRefAndAccount removes a reference on the range fr and untracks segments // that are removed from memory accounting. -func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) { seg := refs.FindSegment(fr.Start) for seg.Ok() && seg.Start() < fr.End { diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index e82afd112..ef0113b52 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -126,7 +125,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { // offsets in fr or until the next call to UnmapAll. // // Preconditions: The caller must hold a reference on all offsets in fr. -func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) { +func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) { chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift) f.mapsMu.Lock() defer f.mapsMu.Unlock() @@ -146,7 +145,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) } // Preconditions: f.mapsMu must be locked. -func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error { +func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error { prot := syscall.PROT_READ if write { prot |= syscall.PROT_WRITE diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 78fec553e..c15d8a946 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -21,18 +21,17 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// HostMappable implements memmap.Mappable and platform.File over a +// HostMappable implements memmap.Mappable and memmap.File over a // CachedFileObject. // // Lock order (compare the lock order model in mm/mm.go): // truncateMu ("fs locks") // mu ("memmap.Mappable locks not taken by Translate") -// ("platform.File locks") +// ("memmap.File locks") // backingFile ("CachedFileObject locks") // // +stateify savable @@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error { return nil } -// MapInternal implements platform.File.MapInternal. -func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (h *HostMappable) FD() int { return h.backingFile.FD() } -// IncRef implements platform.File.IncRef. -func (h *HostMappable) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (h *HostMappable) IncRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.IncRefOn(mr) } -// DecRef implements platform.File.DecRef. -func (h *HostMappable) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (h *HostMappable) DecRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.DecRefOn(mr) } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 800c8b4e1..fe8b0b6ac 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -26,7 +26,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -934,7 +933,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. c.mapsMu.Lock() @@ -999,10 +998,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable } } -// IncRef implements platform.File.IncRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// IncRef implements memmap.File.IncRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { +func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg, gap := c.refs.Find(fr.Start) @@ -1024,10 +1023,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { } } -// DecRef implements platform.File.DecRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// DecRef implements memmap.File.DecRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { +func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg := c.refs.FindSegment(fr.Start) @@ -1046,15 +1045,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { c.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. This is used when we +// MapInternal implements memmap.File.MapInternal. This is used when we // directly map an underlying host fd and CachingInodeOperations is used as the -// platform.File during translation. -func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// memmap.File during translation. +func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// FD implements memmap.File.FD. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. func (c *CachingInodeOperations) FD() int { return c.backingFile.FD() diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 02317a133..09f142cfc 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -221,12 +220,12 @@ func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequenc return 0, syserror.EINVAL } mr := memmap.MappableRange{pgstart, pgend} - var freed []platform.FileRange + var freed []memmap.FileRange d.dataMu.Lock() cseg := d.cache.LowerBoundSegment(mr.Start) for cseg.Ok() && cseg.Start() < mr.End { cseg = d.cache.Isolate(cseg, mr) - freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) + freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) cseg = d.cache.Remove(cseg).NextSegment() } d.dataMu.Unlock() @@ -821,7 +820,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (d *dentry) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. d.mapsMu.Lock() @@ -869,8 +868,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { } } -// dentryPlatformFile implements platform.File. It exists solely because dentry -// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef. +// dentryPlatformFile implements memmap.File. It exists solely because dentry +// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef. // // dentryPlatformFile is only used when a host FD representing the remote file // is available (i.e. dentry.handle.fd >= 0), and that FD is used for @@ -878,7 +877,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { type dentryPlatformFile struct { *dentry - // fdRefs counts references on platform.File offsets. fdRefs is protected + // fdRefs counts references on memmap.File offsets. fdRefs is protected // by dentry.dataMu. fdRefs fsutil.FrameRefSet @@ -890,29 +889,29 @@ type dentryPlatformFile struct { hostFileMapperInitOnce sync.Once } -// IncRef implements platform.File.IncRef. -func (d *dentryPlatformFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.IncRefAndAccount(fr) d.dataMu.Unlock() } -// DecRef implements platform.File.DecRef. -func (d *dentryPlatformFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.DecRefAndAccount(fr) d.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. -func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write) d.handleMu.RUnlock() return bs, err } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (d *dentryPlatformFile) FD() int { d.handleMu.RLock() fd := d.handle.fd diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index e86fbe2d5..bd701bbc7 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -34,7 +34,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/platform", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go index 8545a82f0..65d3af38c 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -19,13 +19,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// inodePlatformFile implements platform.File. It exists solely because inode -// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef. +// inodePlatformFile implements memmap.File. It exists solely because inode +// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // // inodePlatformFile should only be used if inode.canMap is true. type inodePlatformFile struct { @@ -34,7 +33,7 @@ type inodePlatformFile struct { // fdRefsMu protects fdRefs. fdRefsMu sync.Mutex - // fdRefs counts references on platform.File offsets. It is used solely for + // fdRefs counts references on memmap.File offsets. It is used solely for // memory accounting. fdRefs fsutil.FrameRefSet @@ -45,32 +44,32 @@ type inodePlatformFile struct { fileMapperInitOnce sync.Once } -// IncRef implements platform.File.IncRef. +// IncRef implements memmap.File.IncRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) IncRef(fr platform.FileRange) { +func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.IncRefAndAccount(fr) i.fdRefsMu.Unlock() } -// DecRef implements platform.File.DecRef. +// DecRef implements memmap.File.DecRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) DecRef(fr platform.FileRange) { +func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.DecRefAndAccount(fr) i.fdRefsMu.Unlock() } -// MapInternal implements platform.File.MapInternal. +// MapInternal implements memmap.File.MapInternal. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (i *inodePlatformFile) FD() int { return i.hostFD } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index bfd779837..c211fc8d0 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -20,7 +20,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index f66cfcc7f..55b4c2cdb 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -45,7 +45,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -370,7 +369,7 @@ type Shm struct { // fr is the offset into mfp.MemoryFile() that backs this contents of this // segment. Immutable. - fr platform.FileRange + fr memmap.FileRange // mu protects all fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 5f3908d8b..7c4fefb16 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/log" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" ) @@ -90,7 +90,7 @@ type Timekeeper struct { // NewTimekeeper does not take ownership of paramPage. // // SetClocks must be called on the returned Timekeeper before it is usable. -func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) { +func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) { return &Timekeeper{ params: NewVDSOParamPage(mfp, paramPage), }, nil diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index f1b3c212c..290c32466 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -58,7 +58,7 @@ type vdsoParams struct { type VDSOParamPage struct { // The parameter page is fr, allocated from mfp.MemoryFile(). mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange // seq is the current sequence count written to the page. // @@ -81,7 +81,7 @@ type VDSOParamPage struct { // * VDSOParamPage must be the only writer to fr. // // * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block. -func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage { +func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage { return &VDSOParamPage{mfp: mfp, fr: fr} } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index a98b66de1..2c95669cd 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -28,9 +28,21 @@ go_template_instance( }, ) +go_template_instance( + name = "file_range", + out = "file_range.go", + package = "memmap", + prefix = "File", + template = "//pkg/segment:generic_range", + types = { + "T": "uint64", + }, +) + go_library( name = "memmap", srcs = [ + "file_range.go", "mappable_range.go", "mapping_set.go", "mapping_set_impl.go", @@ -40,7 +52,7 @@ go_library( deps = [ "//pkg/context", "//pkg/log", - "//pkg/sentry/platform", + "//pkg/safemem", "//pkg/syserror", "//pkg/usermem", ], diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index c6db9fc8f..c188f6c29 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -19,12 +19,12 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/usermem" ) // Mappable represents a memory-mappable object, a mutable mapping from uint64 -// offsets to (platform.File, uint64 File offset) pairs. +// offsets to (File, uint64 File offset) pairs. // // See mm/mm.go for Mappable's place in the lock order. // @@ -74,7 +74,7 @@ type Mappable interface { // Translations are valid until invalidated by a callback to // MappingSpace.Invalidate or until the caller removes its mapping of the // translated range. Mappable implementations must ensure that at least one - // reference is held on all pages in a platform.File that may be the result + // reference is held on all pages in a File that may be the result // of a valid Translation. // // Preconditions: required.Length() > 0. optional.IsSupersetOf(required). @@ -100,7 +100,7 @@ type Translation struct { Source MappableRange // File is the mapped file. - File platform.File + File File // Offset is the offset into File at which this Translation begins. Offset uint64 @@ -110,9 +110,9 @@ type Translation struct { Perms usermem.AccessType } -// FileRange returns the platform.FileRange represented by t. -func (t Translation) FileRange() platform.FileRange { - return platform.FileRange{t.Offset, t.Offset + t.Source.Length()} +// FileRange returns the FileRange represented by t. +func (t Translation) FileRange() FileRange { + return FileRange{t.Offset, t.Offset + t.Source.Length()} } // CheckTranslateResult returns an error if (ts, terr) does not satisfy all @@ -361,3 +361,49 @@ type MMapOpts struct { // TODO(jamieliu): Replace entirely with MappingIdentity? Hint string } + +// File represents a host file that may be mapped into an platform.AddressSpace. +type File interface { + // All pages in a File are reference-counted. + + // IncRef increments the reference count on all pages in fr. + // + // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > + // 0. At least one reference must be held on all pages in fr. (The File + // interface does not provide a way to acquire an initial reference; + // implementors may define mechanisms for doing so.) + IncRef(fr FileRange) + + // DecRef decrements the reference count on all pages in fr. + // + // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > + // 0. At least one reference must be held on all pages in fr. + DecRef(fr FileRange) + + // MapInternal returns a mapping of the given file offsets in the invoking + // process' address space for reading and writing. + // + // Note that fr.Start and fr.End need not be page-aligned. + // + // Preconditions: fr.Length() > 0. At least one reference must be held on + // all pages in fr. + // + // Postconditions: The returned mapping is valid as long as at least one + // reference is held on the mapped pages. + MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) + + // FD returns the file descriptor represented by the File. + // + // The only permitted operation on the returned file descriptor is to map + // pages from it consistent with the requirements of AddressSpace.MapFile. + FD() int +} + +// FileRange represents a range of uint64 offsets into a File. +// +// type FileRange <generated using go_generics> + +// String implements fmt.Stringer.String. +func (fr FileRange) String() string { + return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) +} diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index a036ce53c..f9d0837a1 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -7,14 +7,14 @@ go_template_instance( name = "file_refcount_set", out = "file_refcount_set.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "mm", prefix = "fileRefcount", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "int32", "Functions": "fileRefcountSetFunctions", }, diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 379148903..1999ec706 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -243,7 +242,7 @@ type aioMappable struct { refs.AtomicRefCount mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange } var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp()) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 6db7c3d40..3e85964e4 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -25,7 +25,7 @@ // Locks taken by memmap.Mappable.Translate // mm.privateRefs.mu // platform.AddressSpace locks -// platform.File locks +// memmap.File locks // mm.aioManager.mu // mm.AIOContext.mu // @@ -396,7 +396,7 @@ type pma struct { // file is the file mapped by this pma. Only pmas for which file == // MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to // the corresponding file range while they exist. - file platform.File `state:"nosave"` + file memmap.File `state:"nosave"` // off is the offset into file at which this pma begins. // @@ -436,7 +436,7 @@ type pma struct { private bool // If internalMappings is not empty, it is the cached return value of - // file.MapInternal for the platform.FileRange mapped by this pma. + // file.MapInternal for the memmap.FileRange mapped by this pma. internalMappings safemem.BlockSeq `state:"nosave"` } @@ -469,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 { func (fileRefcountSetFunctions) ClearValue(_ *int32) { } -func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) { +func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) { return rc1, rc1 == rc2 } -func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) { +func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) { return rc, rc } diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 62e4c20af..930ec895f 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -604,7 +603,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat } } -// Pin returns the platform.File ranges currently mapped by addresses in ar in +// Pin returns the memmap.File ranges currently mapped by addresses in ar in // mm, acquiring a reference on the returned ranges which the caller must // release by calling Unpin. If not all addresses are mapped, Pin returns a // non-nil error. Note that Pin may return both a non-empty slice of @@ -674,15 +673,15 @@ type PinnedRange struct { Source usermem.AddrRange // File is the mapped file. - File platform.File + File memmap.File // Offset is the offset into File at which this PinnedRange begins. Offset uint64 } -// FileRange returns the platform.File offsets mapped by pr. -func (pr PinnedRange) FileRange() platform.FileRange { - return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} +// FileRange returns the memmap.File offsets mapped by pr. +func (pr PinnedRange) FileRange() memmap.FileRange { + return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} } // Unpin releases the reference held by prs. @@ -857,7 +856,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf } // incPrivateRef acquires a reference on private pages in fr. -func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { +func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) { mm.privateRefs.mu.Lock() defer mm.privateRefs.mu.Unlock() refSet := &mm.privateRefs.refs @@ -878,8 +877,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { } // decPrivateRef releases a reference on private pages in fr. -func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) { - var freed []platform.FileRange +func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) { + var freed []memmap.FileRange mm.privateRefs.mu.Lock() refSet := &mm.privateRefs.refs @@ -951,7 +950,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa // Discard internal mappings instead of trying to merge them, since merging // them requires an allocation and getting them again from the - // platform.File might not. + // memmap.File might not. pma1.internalMappings = safemem.BlockSeq{} return pma1, true } @@ -1012,12 +1011,12 @@ func (pseg pmaIterator) getInternalMappingsLocked() error { return nil } -func (pseg pmaIterator) fileRange() platform.FileRange { +func (pseg pmaIterator) fileRange() memmap.FileRange { return pseg.fileRangeOf(pseg.Range()) } // Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0. -func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { +func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange { if checkInvariants { if !pseg.Ok() { panic("terminal pma iterator") @@ -1032,5 +1031,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { pma := pseg.ValuePtr() pstart := pseg.Start() - return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} + return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} } diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 9ad52082d..0e142fb11 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -35,7 +34,7 @@ type SpecialMappable struct { refs.AtomicRefCount mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange name string } @@ -44,7 +43,7 @@ type SpecialMappable struct { // SpecialMappable will use the given name in /proc/[pid]/maps. // // Preconditions: fr.Length() != 0. -func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable { +func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable { m := SpecialMappable{mfp: mfp, fr: fr, name: name} m.EnableLeakCheck("mm.SpecialMappable") return &m @@ -126,7 +125,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider { // FileRange returns the offsets into MemoryFileProvider().MemoryFile() that // store the SpecialMappable's contents. -func (m *SpecialMappable) FileRange() platform.FileRange { +func (m *SpecialMappable) FileRange() memmap.FileRange { return m.fr } diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index e1fcb175f..7a3311a70 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -36,14 +36,14 @@ go_template_instance( "trackGaps": "1", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "usage", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "usageInfo", "Functions": "usageSetFunctions", }, @@ -56,14 +56,14 @@ go_template_instance( "minDegree": "10", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "reclaim", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "reclaimSetValue", "Functions": "reclaimSetFunctions", }, @@ -89,7 +89,7 @@ go_library( "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/hostmm", - "//pkg/sentry/platform", + "//pkg/sentry/memmap", "//pkg/sentry/usage", "//pkg/state", "//pkg/state/wire", diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index afab97c0a..3243d7214 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -33,14 +33,14 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostmm" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) -// MemoryFile is a platform.File whose pages may be allocated to arbitrary +// MemoryFile is a memmap.File whose pages may be allocated to arbitrary // users. type MemoryFile struct { // opts holds options passed to NewMemoryFile. opts is immutable. @@ -372,7 +372,7 @@ func (f *MemoryFile) Destroy() { // to Allocate. // // Preconditions: length must be page-aligned and non-zero. -func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) { +func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) { if length == 0 || length%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid allocation length: %#x", length)) } @@ -390,7 +390,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Find a range in the underlying file. fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment) if !ok { - return platform.FileRange{}, syserror.ENOMEM + return memmap.FileRange{}, syserror.ENOMEM } // Expand the file if needed. @@ -398,7 +398,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Round the new file size up to be chunk-aligned. newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask if err := f.file.Truncate(newFileSize); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } f.fileSize = newFileSize f.mappingsMu.Lock() @@ -416,7 +416,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi bs[i] = 0 } }); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } } if !f.usage.Add(fr, usageInfo{ @@ -439,7 +439,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // space for mappings to be allocated downwards. // // Precondition: alignment must be a power of 2. -func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) { +func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) { alignmentMask := alignment - 1 // Search for space in existing gaps, starting at the current end of the @@ -461,7 +461,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 break } if start := unalignedStart &^ alignmentMask; start >= gap.Start() { - return platform.FileRange{start, start + length}, true + return memmap.FileRange{start, start + length}, true } gap = gap.PrevLargeEnoughGap(length) @@ -475,7 +475,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 min = (min + alignmentMask) &^ alignmentMask if min+length < min { // Overflow: allocation would exceed the range of uint64. - return platform.FileRange{}, false + return memmap.FileRange{}, false } // Determine the minimum file size required to fit this allocation at its end. @@ -484,7 +484,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 if newFileSize <= fileSize { if fileSize != 0 { // Overflow: allocation would exceed the range of int64. - return platform.FileRange{}, false + return memmap.FileRange{}, false } newFileSize = chunkSize } @@ -496,7 +496,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 continue } if start := unalignedStart &^ alignmentMask; start >= min { - return platform.FileRange{start, start + length}, true + return memmap.FileRange{start, start + length}, true } } } @@ -508,22 +508,22 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 // by r.ReadToBlocks(), it returns that error. // // Preconditions: length > 0. length must be page-aligned. -func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) { +func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) { fr, err := f.Allocate(length, kind) if err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } dsts, err := f.MapInternal(fr, usermem.Write) if err != nil { f.DecRef(fr) - return platform.FileRange{}, err + return memmap.FileRange{}, err } n, err := safemem.ReadFullToBlocks(r, dsts) un := uint64(usermem.Addr(n).RoundDown()) if un < length { // Free unused memory and update fr to contain only the memory that is // still allocated. - f.DecRef(platform.FileRange{fr.Start + un, fr.End}) + f.DecRef(memmap.FileRange{fr.Start + un, fr.End}) fr.End = fr.Start + un } return fr, err @@ -540,7 +540,7 @@ const ( // will read zeroes. // // Preconditions: fr.Length() > 0. -func (f *MemoryFile) Decommit(fr platform.FileRange) error { +func (f *MemoryFile) Decommit(fr memmap.FileRange) error { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -560,7 +560,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error { return nil } -func (f *MemoryFile) markDecommitted(fr platform.FileRange) { +func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() // Since we're changing the knownCommitted attribute, we need to merge @@ -581,8 +581,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) { f.usage.MergeRange(fr) } -// IncRef implements platform.File.IncRef. -func (f *MemoryFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (f *MemoryFile) IncRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -600,8 +600,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) { f.usage.MergeAdjacent(fr) } -// DecRef implements platform.File.DecRef. -func (f *MemoryFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (f *MemoryFile) DecRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -637,8 +637,8 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { } } -// MapInternal implements platform.File.MapInternal. -func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { if !fr.WellFormed() || fr.Length() == 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -664,7 +664,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) ( // forEachMappingSlice invokes fn on a sequence of byte slices that // collectively map all bytes in fr. -func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error { +func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error { mappings := f.mappings.Load().([]uintptr) for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize { chunk := int(chunkStart >> chunkShift) @@ -944,7 +944,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( continue case !populated && populatedRun: // Finish the run by changing this segment. - runRange := platform.FileRange{ + runRange := memmap.FileRange{ Start: r.Start + uint64(populatedRunStart*usermem.PageSize), End: r.Start + uint64(i*usermem.PageSize), } @@ -1009,7 +1009,7 @@ func (f *MemoryFile) File() *os.File { return f.file } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (f *MemoryFile) FD() int { return int(f.file.Fd()) } @@ -1090,13 +1090,13 @@ func (f *MemoryFile) runReclaim() { // // Note that there returned range will be removed from tracking. It // must be reclaimed (removed from f.usage) at this point. -func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { +func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) { f.mu.Lock() defer f.mu.Unlock() for { for { if f.destroyed { - return platform.FileRange{}, false + return memmap.FileRange{}, false } if f.reclaimable { break @@ -1120,7 +1120,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { } } -func (f *MemoryFile) markReclaimed(fr platform.FileRange) { +func (f *MemoryFile) markReclaimed(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() seg := f.usage.FindSegment(fr.Start) @@ -1222,11 +1222,11 @@ func (usageSetFunctions) MaxKey() uint64 { func (usageSetFunctions) ClearValue(val *usageInfo) { } -func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) { +func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) { return val1, val1 == val2 } -func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { +func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { return val, val } @@ -1270,10 +1270,10 @@ func (reclaimSetFunctions) MaxKey() uint64 { func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) { } -func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { +func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { return reclaimSetValue{}, true } -func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { +func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { return reclaimSetValue{}, reclaimSetValue{} } diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 453241eca..209b28053 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,39 +1,21 @@ load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -go_template_instance( - name = "file_range", - out = "file_range.go", - package = "platform", - prefix = "File", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - go_library( name = "platform", srcs = [ "context.go", - "file_range.go", "mmap_min_addr.go", "platform.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/atomicbitops", "//pkg/context", - "//pkg/log", - "//pkg/safecopy", - "//pkg/safemem", "//pkg/seccomp", "//pkg/sentry/arch", - "//pkg/sentry/usage", - "//pkg/syserror", + "//pkg/sentry/memmap", "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 10a10bfe2..b5d27a72a 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -47,6 +47,7 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/ring0", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index faf1d5e1c..98a3e539d 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sync" @@ -150,7 +151,7 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem. } // MapFile implements platform.AddressSpace.MapFile. -func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { as.mu.Lock() defer as.mu.Unlock() diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index 171513f3f..4b13eec30 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -22,9 +22,9 @@ import ( "os" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/usermem" ) @@ -207,7 +207,7 @@ type AddressSpace interface { // Preconditions: addr and fr must be page-aligned. fr.Length() > 0. // at.Any() == true. At least one reference must be held on all pages in // fr, and must continue to be held as long as pages are mapped. - MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error + MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error // Unmap unmaps the given range. // @@ -310,52 +310,6 @@ func (f SegmentationFault) Error() string { return fmt.Sprintf("segmentation fault at %#x", f.Addr) } -// File represents a host file that may be mapped into an AddressSpace. -type File interface { - // All pages in a File are reference-counted. - - // IncRef increments the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. (The File - // interface does not provide a way to acquire an initial reference; - // implementors may define mechanisms for doing so.) - IncRef(fr FileRange) - - // DecRef decrements the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. - DecRef(fr FileRange) - - // MapInternal returns a mapping of the given file offsets in the invoking - // process' address space for reading and writing. - // - // Note that fr.Start and fr.End need not be page-aligned. - // - // Preconditions: fr.Length() > 0. At least one reference must be held on - // all pages in fr. - // - // Postconditions: The returned mapping is valid as long as at least one - // reference is held on the mapped pages. - MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) - - // FD returns the file descriptor represented by the File. - // - // The only permitted operation on the returned file descriptor is to map - // pages from it consistent with the requirements of AddressSpace.MapFile. - FD() int -} - -// FileRange represents a range of uint64 offsets into a File. -// -// type FileRange <generated using go_generics> - -// String implements fmt.Stringer.String. -func (fr FileRange) String() string { - return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) -} - // Requirements is used to specify platform specific requirements. type Requirements struct { // RequiresCurrentPIDNS indicates that the sandbox has to be started in the diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 30402c2df..29fd23cc3 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/hostcpu", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sync", diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 2389423b0..c990f3454 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -616,7 +617,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp } // MapFile implements platform.AddressSpace.MapFile. -func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { var flags int if precommit { flags |= syscall.MAP_POPULATE diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD index 83b80c8bc..a5e84658a 100644 --- a/pkg/test/dockerutil/BUILD +++ b/pkg/test/dockerutil/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,6 +10,7 @@ go_library( "dockerutil.go", "exec.go", "network.go", + "profile.go", ], visibility = ["//:sandbox"], deps = [ @@ -23,3 +24,19 @@ go_library( "@com_github_docker_go_connections//nat:go_default_library", ], ) + +go_test( + name = "profile_test", + size = "large", + srcs = [ + "profile_test.go", + ], + library = ":dockerutil", + tags = [ + # Requires docker and runsc to be configured before test runs. + # Also requires the test to be run as root. + "manual", + "local", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md new file mode 100644 index 000000000..870292096 --- /dev/null +++ b/pkg/test/dockerutil/README.md @@ -0,0 +1,86 @@ +# dockerutil + +This package is for creating and controlling docker containers for testing +runsc, gVisor's docker/kubernetes binary. A simple test may look like: + +``` + func TestSuperCool(t *testing.T) { + ctx := context.Background() + c := dockerutil.MakeContainer(ctx, t) + got, err := c.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine" + }, "echo", "super cool") + if err != nil { + t.Fatalf("err was not nil: %v", err) + } + want := "super cool" + if !strings.Contains(got, want){ + t.Fatalf("want: %s, got: %s", want, got) + } + } +``` + +For further examples, see many of our end to end tests elsewhere in the repo, +such as those in //test/e2e or benchmarks at //test/benchmarks. + +dockerutil uses the "official" docker golang api, which is +[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil +is a thin wrapper around this API, allowing desired new use cases to be easily +implemented. + +## Profiling + +dockerutil is capable of generating profiles. Currently, the only option is to +use pprof profiles generated by `runsc debug`. The profiler will generate Block, +CPU, Heap, Goroutine, and Mutex profiles. To generate profiles: + +* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc + ARGS="--profile"` Also add other flags with ARGS like `--platform=kvm` or + `--vfs2`. +* Restart docker: `sudo service docker restart` + +To run and generate CPU profiles run: + +``` +make sudo TARGETS=//path/to:target \ + ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` + +Container name in most tests and benchmarks in gVisor is usually the test name +and some random characters like so: +`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2` + +Profiling requires root as runsc debug inspects running containers in /var/run +among other things. + +### Writing for Profiling + +The below shows an example of using profiles with dockerutil. + +``` +func TestSuperCool(t *testing.T){ + ctx := context.Background() + // profiled and using runtime from dockerutil.runtime flag + profiled := MakeContainer() + + // not profiled and using runtime runc + native := MakeNativeContainer() + + err := profiled.Spawn(ctx, RunOpts{ + Image: "some/image", + }, "sleep", "100000") + // profiling has begun here + ... + expensive setup that I don't want to profile. + ... + profiled.RestartProfiles() + // profiled activity +} +``` + +In the above example, `profiled` would be profiled and `native` would not. The +call to `RestartProfiles()` restarts the clock on profiling. This is useful if +the main activity being tested is done with `docker exec` or `container.Spawn()` +followed by one or more `container.Exec()` calls. diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 17acdaf6f..b59503188 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -43,15 +43,21 @@ import ( // See: https://pkg.go.dev/github.com/docker/docker. type Container struct { Name string - Runtime string + runtime string logger testutil.Logger client *client.Client id string mounts []mount.Mount links []string - cleanups []func() copyErr error + cleanups []func() + + // Profiles are profiles added to this container. They contain methods + // that are run after Creation, Start, and Cleanup of this Container, along + // a handle to restart the profile. Generally, tests/benchmarks using + // profiles need to run as root. + profiles []Profile // Stores streams attached to the container. Used by WaitForOutputSubmatch. streams types.HijackedResponse @@ -106,7 +112,19 @@ type RunOpts struct { // MakeContainer sets up the struct for a Docker container. // // Names of containers will be unique. +// Containers will check flags for profiling requests. func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { + c := MakeNativeContainer(ctx, logger) + c.runtime = *runtime + if p := MakePprofFromFlags(c); p != nil { + c.AddProfile(p) + } + return c +} + +// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native +// containers aren't profiled. +func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { // Slashes are not allowed in container names. name := testutil.RandomID(logger.Name()) name = strings.ReplaceAll(name, "/", "-") @@ -114,20 +132,33 @@ func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { if err != nil { return nil } - client.NegotiateAPIVersion(ctx) - return &Container{ logger: logger, Name: name, - Runtime: *runtime, + runtime: "", client: client, } } +// AddProfile adds a profile to this container. +func (c *Container) AddProfile(p Profile) { + c.profiles = append(c.profiles, p) +} + +// RestartProfiles calls Restart on all profiles for this container. +func (c *Container) RestartProfiles() error { + for _, profile := range c.profiles { + if err := profile.Restart(c); err != nil { + return err + } + } + return nil +} + // Spawn is analogous to 'docker run -d'. func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error { - if err := c.create(ctx, r, args); err != nil { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { return err } return c.Start(ctx) @@ -153,7 +184,7 @@ func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) // Run is analogous to 'docker run'. func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) { - if err := c.create(ctx, r, args); err != nil { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { return "", err } @@ -181,27 +212,25 @@ func (c *Container) MakeLink(target string) string { // CreateFrom creates a container from the given configs. func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { - cont, err := c.client.ContainerCreate(ctx, conf, hostconf, netconf, c.Name) - if err != nil { - return err - } - c.id = cont.ID - return nil + return c.create(ctx, conf, hostconf, netconf) } // Create is analogous to 'docker create'. func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error { - return c.create(ctx, r, args) + return c.create(ctx, c.config(r, args), c.hostConfig(r), nil) } -func (c *Container) create(ctx context.Context, r RunOpts, args []string) error { - conf := c.config(r, args) - hostconf := c.hostConfig(r) +func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name) if err != nil { return err } c.id = cont.ID + for _, profile := range c.profiles { + if err := profile.OnCreate(c); err != nil { + return fmt.Errorf("OnCreate method failed with: %v", err) + } + } return nil } @@ -227,7 +256,7 @@ func (c *Container) hostConfig(r RunOpts) *container.HostConfig { c.mounts = append(c.mounts, r.Mounts...) return &container.HostConfig{ - Runtime: c.Runtime, + Runtime: c.runtime, Mounts: c.mounts, PublishAllPorts: true, Links: r.Links, @@ -261,8 +290,15 @@ func (c *Container) Start(ctx context.Context) error { c.cleanups = append(c.cleanups, func() { c.streams.Close() }) - - return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}) + if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { + return fmt.Errorf("ContainerStart failed: %v", err) + } + for _, profile := range c.profiles { + if err := profile.OnStart(c); err != nil { + return fmt.Errorf("OnStart method failed: %v", err) + } + } + return nil } // Stop is analogous to 'docker stop'. @@ -482,6 +518,12 @@ func (c *Container) Remove(ctx context.Context) error { // CleanUp kills and deletes the container (best effort). func (c *Container) CleanUp(ctx context.Context) { + // Execute profile cleanups before the container goes down. + for _, profile := range c.profiles { + profile.OnCleanUp(c) + } + // Forget profiles. + c.profiles = nil // Kill the container. if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") { // Just log; can't do anything here. diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index df09babf3..5a9dd8bd8 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -25,6 +25,7 @@ import ( "os/exec" "regexp" "strconv" + "time" "gvisor.dev/gvisor/pkg/test/testutil" ) @@ -42,6 +43,26 @@ var ( // config is the default Docker daemon configuration path. config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths") + + // The following flags are for the "pprof" profiler tool. + + // pprofBaseDir allows the user to change the directory to which profiles are + // written. By default, profiles will appear under: + // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof. + pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)") + + // duration is the max duration `runsc debug` will run and capture profiles. + // If the container's clean up method is called prior to duration, the + // profiling process will be killed. + duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds") + + // The below flags enable each type of profile. Multiple profiles can be + // enabled for each run. + pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug") + pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug") + pprofGo = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug") + pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug") + pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug") ) // EnsureSupportedDockerVersion checks if correct docker is installed. diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go new file mode 100644 index 000000000..1fab33083 --- /dev/null +++ b/pkg/test/dockerutil/profile.go @@ -0,0 +1,152 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "time" +) + +// Profile represents profile-like operations on a container, +// such as running perf or pprof. It is meant to be added to containers +// such that the container type calls the Profile during its lifecycle. +type Profile interface { + // OnCreate is called just after the container is created when the container + // has a valid ID (e.g. c.ID()). + OnCreate(c *Container) error + + // OnStart is called just after the container is started when the container + // has a valid Pid (e.g. c.SandboxPid()). + OnStart(c *Container) error + + // Restart restarts the Profile on request. + Restart(c *Container) error + + // OnCleanUp is called during the container's cleanup method. + // Cleanups should just log errors if they have them. + OnCleanUp(c *Container) error +} + +// Pprof is for running profiles with 'runsc debug'. Pprof workloads +// should be run as root and ONLY against runsc sandboxes. The runtime +// should have --profile set as an option in /etc/docker/daemon.json in +// order for profiling to work with Pprof. +type Pprof struct { + BasePath string // path to put profiles + BlockProfile bool + CPUProfile bool + GoRoutineProfile bool + HeapProfile bool + MutexProfile bool + Duration time.Duration // duration to run profiler e.g. '10s' or '1m'. + shouldRun bool + cmd *exec.Cmd + stdout io.ReadCloser + stderr io.ReadCloser +} + +// MakePprofFromFlags makes a Pprof profile from flags. +func MakePprofFromFlags(c *Container) *Pprof { + if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) { + return nil + } + return &Pprof{ + BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), + BlockProfile: *pprofBlock, + CPUProfile: *pprofCPU, + GoRoutineProfile: *pprofGo, + HeapProfile: *pprofHeap, + MutexProfile: *pprofMutex, + Duration: *duration, + } +} + +// OnCreate implements Profile.OnCreate. +func (p *Pprof) OnCreate(c *Container) error { + return os.MkdirAll(p.BasePath, 0755) +} + +// OnStart implements Profile.OnStart. +func (p *Pprof) OnStart(c *Container) error { + path, err := RuntimePath() + if err != nil { + return fmt.Errorf("failed to get runtime path: %v", err) + } + + // The root directory of this container's runtime. + root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) + // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`. + args := []string{root, "debug"} + args = append(args, p.makeProfileArgs(c)...) + args = append(args, c.ID()) + + // Best effort wait until container is running. + for now := time.Now(); time.Since(now) < 5*time.Second; { + if status, err := c.Status(context.Background()); err != nil { + return fmt.Errorf("failed to get status with: %v", err) + + } else if status.Running { + break + } + time.Sleep(500 * time.Millisecond) + } + p.cmd = exec.Command(path, args...) + if err := p.cmd.Start(); err != nil { + return fmt.Errorf("process failed: %v", err) + } + return nil +} + +// Restart implements Profile.Restart. +func (p *Pprof) Restart(c *Container) error { + p.OnCleanUp(c) + return p.OnStart(c) +} + +// OnCleanUp implements Profile.OnCleanup +func (p *Pprof) OnCleanUp(c *Container) error { + defer func() { p.cmd = nil }() + if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() { + return p.cmd.Process.Kill() + } + return nil +} + +// makeProfileArgs turns Pprof fields into runsc debug flags. +func (p *Pprof) makeProfileArgs(c *Container) []string { + var ret []string + if p.BlockProfile { + ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof"))) + } + if p.CPUProfile { + ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof"))) + } + if p.GoRoutineProfile { + ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof"))) + } + if p.HeapProfile { + ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof"))) + } + if p.MutexProfile { + ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof"))) + } + ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration)) + return ret +} diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go new file mode 100644 index 000000000..b7b4d7618 --- /dev/null +++ b/pkg/test/dockerutil/profile_test.go @@ -0,0 +1,117 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" +) + +type testCase struct { + name string + pprof Pprof + expectedFiles []string +} + +func TestPprof(t *testing.T) { + // Basepath and expected file names for each type of profile. + basePath := "/tmp/test/profile" + block := "block.pprof" + cpu := "cpu.pprof" + goprofle := "go.pprof" + heap := "heap.pprof" + mutex := "mutex.pprof" + + testCases := []testCase{ + { + name: "Cpu", + pprof: Pprof{ + BasePath: basePath, + CPUProfile: true, + Duration: 2 * time.Second, + }, + expectedFiles: []string{cpu}, + }, + { + name: "All", + pprof: Pprof{ + BasePath: basePath, + BlockProfile: true, + CPUProfile: true, + GoRoutineProfile: true, + HeapProfile: true, + MutexProfile: true, + Duration: 2 * time.Second, + }, + expectedFiles: []string{block, cpu, goprofle, heap, mutex}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + c := MakeContainer(ctx, t) + // Set basepath to include the container name so there are no conflicts. + tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name) + c.AddProfile(&tc.pprof) + + func() { + defer c.CleanUp(ctx) + // Start a container. + if err := c.Spawn(ctx, RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { + t.Fatalf("run failed with: %v", err) + } + + if status, err := c.Status(context.Background()); !status.Running { + t.Fatalf("container is not yet running: %+v err: %v", status, err) + } + + // End early if the expected files exist and have data. + for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) { + if err := checkFiles(tc); err == nil { + break + } + } + }() + + // Check all expected files exist and have data. + if err := checkFiles(tc); err != nil { + t.Fatalf(err.Error()) + } + }) + } +} + +func checkFiles(tc testCase) error { + for _, file := range tc.expectedFiles { + stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file)) + if err != nil { + return fmt.Errorf("stat failed with: %v", err) + } else if stat.Size() < 1 { + return fmt.Errorf("file not written to: %+v", stat) + } + } + return nil +} + +func TestMain(m *testing.M) { + EnsureSupportedDockerVersion() + os.Exit(m.Run()) +} diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index cd76645bd..5e8247bc8 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -643,7 +643,9 @@ func TestExec(t *testing.T) { if err != nil { t.Fatalf("error creating temporary directory: %v", err) } - cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100", dir) + // Note that some shells may exec the final command in a sequence as + // an optimization. We avoid this here by adding the exit 0. + cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100 && exit 0", dir) spec := testutil.NewSpecWithArgs("sh", "-c", cmd) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index 82a46910e..ebefeacf2 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -48,36 +48,6 @@ const ( openFlags = syscall.O_NOFOLLOW | syscall.O_CLOEXEC ) -type fileType int - -const ( - regular fileType = iota - directory - symlink - socket - unknown -) - -// String implements fmt.Stringer. -func (f fileType) String() string { - switch f { - case regular: - return "regular" - case directory: - return "directory" - case symlink: - return "symlink" - case socket: - return "socket" - } - return "unknown" -} - -// ControlSocketAddr generates an abstract unix socket name for the given id. -func ControlSocketAddr(id string) string { - return fmt.Sprintf("\x00runsc-gofer.%s", id) -} - // Config sets configuration options for each attach point. type Config struct { // ROMount is set to true if this is a readonly mount. @@ -199,8 +169,6 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { // entire file up when it's opened in write mode, and would perform badly when // multiple files are only being opened for read (esp. startup). type localFile struct { - p9.DefaultWalkGetAttr - // attachPoint is the attachPoint that serves this localFile. attachPoint *attachPoint @@ -220,8 +188,11 @@ type localFile struct { // if localFile isn't opened. mode p9.OpenFlags - // ft is the fileType for this file. - ft fileType + // fileType for this file. It is equivalent to: + // syscall.Stat_t.Mode & syscall.S_IFMT + fileType uint32 + + qid p9.QID // readDirMu protects against concurrent Readdir calls. readDirMu sync.Mutex @@ -308,29 +279,24 @@ func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, return nil, false, extractErrno(err) } -func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) { - var ft fileType +func checkSupportedFileType(stat syscall.Stat_t, permitSocket bool) error { switch stat.Mode & syscall.S_IFMT { - case syscall.S_IFREG: - ft = regular - case syscall.S_IFDIR: - ft = directory - case syscall.S_IFLNK: - ft = symlink + case syscall.S_IFREG, syscall.S_IFDIR, syscall.S_IFLNK: + return nil + case syscall.S_IFSOCK: if !permitSocket { - return unknown, syscall.EPERM + return syscall.EPERM } - ft = socket + return nil + default: - return unknown, syscall.EPERM + return syscall.EPERM } - return ft, nil } func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat syscall.Stat_t) (*localFile, error) { - ft, err := getSupportedFileType(stat, a.conf.HostUDS) - if err != nil { + if err := checkSupportedFileType(stat, a.conf.HostUDS); err != nil { return nil, err } @@ -339,7 +305,8 @@ func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat hostPath: path, file: file, mode: invalidMode, - ft: ft, + fileType: stat.Mode & syscall.S_IFMT, + qid: a.makeQID(stat), controlReadable: readable, }, nil } @@ -359,7 +326,7 @@ func newFDMaybe(file *fd.FD) *fd.FD { // fd is blocking; non-blocking is required. if err := syscall.SetNonblock(dup.FD(), true); err != nil { - dup.Close() + _ = dup.Close() return nil } return dup @@ -409,16 +376,8 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { } } - stat, err := fstat(newFile.FD()) - if err != nil { - if newFile != l.file { - newFile.Close() - } - return nil, p9.QID{}, 0, extractErrno(err) - } - var fd *fd.FD - if stat.Mode&syscall.S_IFMT == syscall.S_IFREG { + if l.fileType == syscall.S_IFREG { // Donate FD for regular files only. fd = newFDMaybe(newFile) } @@ -431,7 +390,7 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { l.file = newFile } l.mode = flags & p9.OpenFlagsModeMask - return fd, l.attachPoint.makeQID(stat), 0, nil + return fd, l.qid, 0, nil } // Create implements p9.File. @@ -459,7 +418,7 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid return nil, nil, p9.QID{}, 0, extractErrno(err) } cu := cleanup.Make(func() { - child.Close() + _ = child.Close() // Best effort attempt to remove the file in case of failure. if err := syscall.Unlinkat(l.file.FD(), name); err != nil { log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, name), err) @@ -480,10 +439,12 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid hostPath: path.Join(l.hostPath, name), file: child, mode: mode, + fileType: syscall.S_IFREG, + qid: l.attachPoint.makeQID(stat), } cu.Release() - return newFDMaybe(c.file), c, l.attachPoint.makeQID(stat), 0, nil + return newFDMaybe(c.file), c, c.qid, 0, nil } // Mkdir implements p9.File. @@ -529,19 +490,34 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) // Walk implements p9.File. func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) { + qids, file, _, err := l.walk(names) + return qids, file, err +} + +// WalkGetAttr implements p9.File. +func (l *localFile) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) { + qids, file, stat, err := l.walk(names) + if err != nil { + return nil, nil, p9.AttrMask{}, p9.Attr{}, err + } + mask, attr := l.fillAttr(stat) + return qids, file, mask, attr, nil +} + +func (l *localFile) walk(names []string) ([]p9.QID, p9.File, syscall.Stat_t, error) { // Duplicate current file if 'names' is empty. if len(names) == 0 { newFile, readable, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) { return reopenProcFd(l.file, openFlags|mode) }) if err != nil { - return nil, nil, extractErrno(err) + return nil, nil, syscall.Stat_t{}, extractErrno(err) } stat, err := fstat(newFile.FD()) if err != nil { - newFile.Close() - return nil, nil, extractErrno(err) + _ = newFile.Close() + return nil, nil, syscall.Stat_t{}, extractErrno(err) } c := &localFile{ @@ -549,36 +525,39 @@ func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) { hostPath: l.hostPath, file: newFile, mode: invalidMode, + fileType: l.fileType, + qid: l.attachPoint.makeQID(stat), controlReadable: readable, } - return []p9.QID{l.attachPoint.makeQID(stat)}, c, nil + return []p9.QID{c.qid}, c, stat, nil } var qids []p9.QID + var lastStat syscall.Stat_t last := l for _, name := range names { f, path, readable, err := openAnyFileFromParent(last, name) if last != l { - last.Close() + _ = last.Close() } if err != nil { - return nil, nil, extractErrno(err) + return nil, nil, syscall.Stat_t{}, extractErrno(err) } - stat, err := fstat(f.FD()) + lastStat, err = fstat(f.FD()) if err != nil { - f.Close() - return nil, nil, extractErrno(err) + _ = f.Close() + return nil, nil, syscall.Stat_t{}, extractErrno(err) } - c, err := newLocalFile(last.attachPoint, f, path, readable, stat) + c, err := newLocalFile(last.attachPoint, f, path, readable, lastStat) if err != nil { - f.Close() - return nil, nil, extractErrno(err) + _ = f.Close() + return nil, nil, syscall.Stat_t{}, extractErrno(err) } - qids = append(qids, l.attachPoint.makeQID(stat)) + qids = append(qids, c.qid) last = c } - return qids, last, nil + return qids, last, lastStat, nil } // StatFS implements p9.File. @@ -618,7 +597,11 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) if err != nil { return p9.QID{}, p9.AttrMask{}, p9.Attr{}, extractErrno(err) } + mask, attr := l.fillAttr(stat) + return l.qid, mask, attr, nil +} +func (l *localFile) fillAttr(stat syscall.Stat_t) (p9.AttrMask, p9.Attr) { attr := p9.Attr{ Mode: p9.FileMode(stat.Mode), UID: p9.UID(stat.Uid), @@ -647,8 +630,7 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) MTime: true, CTime: true, } - - return l.attachPoint.makeQID(stat), valid, attr, nil + return valid, attr } // SetAttr implements p9.File. Due to mismatch in file API, options @@ -689,7 +671,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { // Check if it's possible to use cached file, or if another one needs to be // opened for write. f := l.file - if l.ft == regular && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { + if l.fileType == syscall.S_IFREG && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { var err error f, err = reopenProcFd(l.file, openFlags|os.O_WRONLY) if err != nil { @@ -745,7 +727,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } } - if l.ft == symlink { + if l.fileType == syscall.S_IFLNK { // utimensat operates different that other syscalls. To operate on a // symlink it *requires* AT_SYMLINK_NOFOLLOW with dirFD and a non-empty // name. @@ -929,7 +911,7 @@ func (l *localFile) Link(target p9.File, newName string) error { } // Mknod implements p9.File. -func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid p9.UID, gid p9.GID) (p9.QID, error) { +func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, _ p9.UID, _ p9.GID) (p9.QID, error) { conf := l.attachPoint.conf if conf.ROMount { if conf.PanicOnWrite { @@ -1127,13 +1109,13 @@ func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { } if err := syscall.SetNonblock(f, true); err != nil { - syscall.Close(f) + _ = syscall.Close(f) return nil, err } sa := syscall.SockaddrUnix{Name: l.hostPath} if err := syscall.Connect(f, &sa); err != nil { - syscall.Close(f) + _ = syscall.Close(f) return nil, err } diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index 5b37e6aa1..94f167417 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -32,7 +32,7 @@ import ( var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} var ( - allTypes = []fileType{regular, directory, symlink} + allTypes = []uint32{syscall.S_IFREG, syscall.S_IFDIR, syscall.S_IFLNK} // allConfs is set in init(). allConfs []Config @@ -109,24 +109,37 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error { } type state struct { - root *localFile - file *localFile - conf Config - ft fileType + root *localFile + file *localFile + conf Config + fileType uint32 } func (s state) String() string { - return fmt.Sprintf("type(%v)", s.ft) + return fmt.Sprintf("type(%v)", s.fileType) +} + +func typeName(fileType uint32) string { + switch fileType { + case syscall.S_IFREG: + return "file" + case syscall.S_IFDIR: + return "directory" + case syscall.S_IFLNK: + return "symlink" + default: + panic(fmt.Sprintf("invalid file type for test: %d", fileType)) + } } func runAll(t *testing.T, test func(*testing.T, state)) { runCustom(t, allTypes, allConfs, test) } -func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testing.T, state)) { +func runCustom(t *testing.T, types []uint32, confs []Config, test func(*testing.T, state)) { for _, c := range confs { for _, ft := range types { - name := fmt.Sprintf("%s/%v", configTestName(&c), ft) + name := fmt.Sprintf("%s/%s", configTestName(&c), typeName(ft)) t.Run(name, func(t *testing.T) { path, name, err := setup(ft) if err != nil { @@ -150,10 +163,10 @@ func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testin } st := state{ - root: root.(*localFile), - file: file.(*localFile), - conf: c, - ft: ft, + root: root.(*localFile), + file: file.(*localFile), + conf: c, + fileType: ft, } test(t, st) file.Close() @@ -163,7 +176,7 @@ func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testin } } -func setup(ft fileType) (string, string, error) { +func setup(fileType uint32) (string, string, error) { path, err := ioutil.TempDir(testutil.TmpDir(), "root-") if err != nil { return "", "", fmt.Errorf("ioutil.TempDir() failed, err: %v", err) @@ -181,26 +194,26 @@ func setup(ft fileType) (string, string, error) { defer root.Close() var name string - switch ft { - case regular: + switch fileType { + case syscall.S_IFREG: name = "file" _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err) } defer f.Close() - case directory: + case syscall.S_IFDIR: name = "dir" if _, err := root.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { return "", "", fmt.Errorf("root.MkDir(%q) failed, err: %v", name, err) } - case symlink: + case syscall.S_IFLNK: name = "symlink" if _, err := root.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { return "", "", fmt.Errorf("root.Symlink(%q) failed, err: %v", name, err) } default: - panic(fmt.Sprintf("unknown file type %v", ft)) + panic(fmt.Sprintf("unknown file type %v", fileType)) } return path, name, nil } @@ -214,7 +227,7 @@ func createFile(dir *localFile, name string) (*localFile, error) { } func TestReadWrite(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { child, err := createFile(s.file, "test") if err != nil { t.Fatalf("%v: createFile() failed, err: %v", s, err) @@ -244,7 +257,7 @@ func TestReadWrite(t *testing.T) { } func TestCreate(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { for i, flags := range allOpenFlags { _, l, _, _, err := s.file.Create(fmt.Sprintf("test-%d", i), flags, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { @@ -261,7 +274,7 @@ func TestCreate(t *testing.T) { // TestReadWriteDup tests that a file opened in any mode can be dup'ed and // reopened in any other mode. func TestReadWriteDup(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { child, err := createFile(s.file, "test") if err != nil { t.Fatalf("%v: createFile() failed, err: %v", s, err) @@ -303,7 +316,7 @@ func TestReadWriteDup(t *testing.T) { } func TestUnopened(t *testing.T) { - runCustom(t, []fileType{regular}, allConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFREG}, allConfs, func(t *testing.T, s state) { b := []byte("foobar") if _, err := s.file.WriteAt(b, 0); err != syscall.EBADF { t.Errorf("%v: WriteAt() should have failed, got: %v, expected: syscall.EBADF", s, err) @@ -325,7 +338,7 @@ func TestUnopened(t *testing.T) { // was open with O_PATH, but Open() was not checking for it and allowing the // control file to be reused. func TestOpenOPath(t *testing.T) { - runCustom(t, []fileType{regular}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFREG}, rwConfs, func(t *testing.T, s state) { // Fist remove all permissions on the file. if err := s.file.SetAttr(p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(0)}); err != nil { t.Fatalf("SetAttr(): %v", err) @@ -362,7 +375,7 @@ func TestSetAttrPerm(t *testing.T) { valid := p9.SetAttrMask{Permissions: true} attr := p9.SetAttr{Permissions: 0777} got, err := SetGetAttr(s.file, valid, attr) - if s.ft == symlink { + if s.fileType == syscall.S_IFLNK { if err == nil { t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) } @@ -383,7 +396,7 @@ func TestSetAttrSize(t *testing.T) { valid := p9.SetAttrMask{Size: true} attr := p9.SetAttr{Size: size} got, err := SetGetAttr(s.file, valid, attr) - if s.ft == symlink || s.ft == directory { + if s.fileType == syscall.S_IFLNK || s.fileType == syscall.S_IFDIR { if err == nil { t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) } @@ -465,7 +478,7 @@ func TestLink(t *testing.T) { } err = dir.Link(s.file, linkFile) - if s.ft == directory { + if s.fileType == syscall.S_IFDIR { if err != syscall.EPERM { t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: syscall.EPERM", s, linkFile, err) } @@ -523,7 +536,7 @@ func TestROMountPanics(t *testing.T) { } func TestWalkNotFound(t *testing.T) { - runCustom(t, []fileType{directory}, allConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, allConfs, func(t *testing.T, s state) { if _, _, err := s.file.Walk([]string{"nobody-here"}); err != syscall.ENOENT { t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: syscall.ENOENT", s, "nobody-here", err) } @@ -544,7 +557,7 @@ func TestWalkDup(t *testing.T) { } func TestReaddir(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { name := "dir" if _, err := s.file.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { t.Fatalf("%v: MkDir(%s) failed, err: %v", s, name, err) diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh deleted file mode 100755 index c49f988b8..000000000 --- a/scripts/benchmark.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -make load-all-images - -if [[ -z "${1:-}" ]]; then - target=$(query "attr(tags, manual, tests(//test/benchmarks/...))") -else - target="$1" -fi - -install_runsc_for_benchmarks benchmark - -echo $target -benchmark_runsc $target "${@:2}" diff --git a/scripts/common.sh b/scripts/common.sh index 36158654f..3ca699e4a 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -42,15 +42,6 @@ function test_runsc() { test --test_arg=--runtime=${RUNTIME} "$@" } -function benchmark_runsc() { - test_runsc -c opt \ - --nocache_test_results \ - --test_arg=-test.bench=. \ - --test_arg=-test.benchmem \ - --jobs=1 \ - "$@" -} - function install_runsc_for_test() { local -r test_name=$1 shift @@ -72,24 +63,6 @@ function install_runsc_for_test() { "$@" } -function install_runsc_for_benchmarks() { - local -r test_name=$1 - shift - if [[ -z "${test_name}" ]]; then - echo "Missing mandatory test name" - exit 1 - fi - - # Add test to the name, so it doesn't conflict with other runtimes. - set_runtime $(find_branch_name)_"${test_name}" - - # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name - # down to the runtime. - install_runsc "${RUNTIME}" \ - --TESTONLY-test-name-env=RUNSC_TEST_NAME \ - "$@" -} - # Installs the runsc with given runtime name. set_runtime must have been called # to set runtime and logs location. function install_runsc() { diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh index dce0a4085..07e9f3109 100755 --- a/scripts/docker_tests.sh +++ b/scripts/docker_tests.sh @@ -22,4 +22,6 @@ install_runsc_for_test docker test_runsc //test/image:image_test //test/e2e:integration_test install_runsc_for_test docker --vfs2 -test_runsc //test/image:image_test --test_filter=.*TestHelloWorld +IMAGE_FILTER="Hello|Httpd|Ruby|Stdio" +INTEGRATION_FILTER="LifeCycle|Pause|Connect|JobControl|Overlay|Exec|DirCreation/root" +test_runsc //test/e2e:integration_test //test/image:image_test --test_filter="${IMAGE_FILTER}|${INTEGRATION_FILTER}" diff --git a/test/benchmarks/README.md b/test/benchmarks/README.md index 9ff602cf1..d1bbabf6f 100644 --- a/test/benchmarks/README.md +++ b/test/benchmarks/README.md @@ -13,33 +13,51 @@ To run benchmarks you will need: * Docker installed (17.09.0 or greater). -The easiest way to run benchmarks is to use the script at -//scripts/benchmark.sh. +The easiest way to setup runsc for running benchmarks is to use the make file. +From the root directory: -If not using the script, you will need: +* Download images: `make load-all-images` +* Install runsc suitable for benchmarking, which should probably not have + strace or debug logs enabled. For example:`make configure RUNTIME=myrunsc + ARGS=--platform=kvm`. +* Restart docker: `sudo service docker restart` -* `runsc` configured with docker +You should now have a runtime with the following options configured in +`/etc/docker/daemon.json` -Note: benchmarks call the runtime by name. If docker can run it with -`--runtime=` flag, these tools should work. +``` +"myrunsc": { + "path": "/tmp/myrunsc/runsc", + "runtimeArgs": [ + "--debug-log", + "/tmp/bench/logs/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%", + "--platform=kvm" + ] + }, + +``` + +This runtime has been configured with a debugging off and strace logs off and is +using kvm for demonstration. ## Running benchmarks -The easiest way to run is with the script at //scripts/benchmarks.sh. The script -will run all benchmarks under //test/benchmarks if a target is not provided. +Given the runtime above runtime `myrunsc`, run benchmarks with the following: -```bash -./script/benchmarks.sh //path/to/target +``` +make sudo TARGETS=//path/to:target ARGS="--runtime=myrunsc -test.v \ + -test.bench=." OPTIONS="-c opt ``` -If you want to run benchmarks manually: - -* Run `make load-all-images` from `//` -* Run with: +For example, to run only the Iperf tests: -```bash -bazel test --test_arg=--runtime=RUNTIME -c opt --test_output=streamed --test_timeout=600 --test_arg=-test.bench=. --nocache_test_results //path/to/target ``` +make sudo TARGETS=//test/benchmarks/network:network_test \ + ARGS="--runtime=myrunsc -test.v -test.bench=Iperf" OPTIONS="-c opt" +``` + +Benchmarks are run with root as some benchmarks require root privileges to do +things like drop caches. ## Writing benchmarks @@ -69,6 +87,7 @@ var h harness.Harness func BenchmarkMyCoolOne(b *testing.B) { machine, err := h.GetMachine() // check err + defer machine.CleanUp() ctx := context.Background() container := machine.GetContainer(ctx, b) @@ -82,7 +101,7 @@ func BenchmarkMyCoolOne(b *testing.B) { Image: "benchmarks/my-cool-image", Env: []string{"MY_VAR=awesome"}, other options...see dockerutil - }, "sh", "-c", "echo MY_VAR" ...) + }, "sh", "-c", "echo MY_VAR") //check err b.StopTimer() @@ -107,12 +126,32 @@ Some notes on the above: flags, remote virtual machines (eventually), and other services. * Respect `b.N` in that users of the benchmark may want to "run for an hour" or something of the sort. -* Use the `b.ReportMetric` method to report custom metrics. +* Use the `b.ReportMetric()` method to report custom metrics. * Set the timer if time is useful for reporting. There isn't a way to turn off default metrics in testing.B (B/op, allocs/op, ns/op). * Take a look at dockerutil at //pkg/test/dockerutil to see all methods available from containers. The API is based on the "official" [docker API for golang](https://pkg.go.dev/mod/github.com/docker/docker). -* `harness.GetMachine` marks how many machines this tests needs. If you have a - client and server and to mark them as multiple machines, call it - `GetMachine` twice. +* `harness.GetMachine()` marks how many machines this tests needs. If you have + a client and server and to mark them as multiple machines, call + `harness.GetMachine()` twice. + +## Profiling + +For profiling, the runtime is required to have the `--profile` flag enabled. +This flag loosens seccomp filters so that the runtime can write profile data to +disk. This configuration is not recommended for production. + +* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc + ARGS="--profile --platform=kvm --vfs2"`. The kvm and vfs2 flags are not + required, but are included for demonstration. +* Restart docker: `sudo service docker restart` + +To run and generate CPU profiles fs_test test run: + +``` +make sudo TARGETS=//test/benchmarks/fs:fs_test \ + ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD new file mode 100644 index 000000000..5e33465cd --- /dev/null +++ b/test/benchmarks/database/BUILD @@ -0,0 +1,28 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "database", + testonly = 1, + srcs = ["database.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "database_test", + size = "enormous", + srcs = [ + "redis_test.go", + ], + library = ":database", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/database/database.go b/test/benchmarks/database/database.go new file mode 100644 index 000000000..9eeb59f9a --- /dev/null +++ b/test/benchmarks/database/database.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package database holds benchmarks around database applications. +package database + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package database. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go new file mode 100644 index 000000000..6d39f4d66 --- /dev/null +++ b/test/benchmarks/database/redis_test.go @@ -0,0 +1,197 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// All possible operations from redis. Note: "ping" will +// run both PING_INLINE and PING_BUILD. +var operations []string = []string{ + "PING_INLINE", + "PING_BULK", + "SET", + "GET", + "INCR", + "LPUSH", + "RPUSH", + "LPOP", + "RPOP", + "SADD", + "HSET", + "SPOP", + "LRANGE_100", + "LRANGE_300", + "LRANGE_500", + "LRANGE_600", + "MSET", +} + +// BenchmarkRedis runs redis-benchmark against a redis instance and reports +// data in queries per second. Each is reported by named operation (e.g. LPUSH). +func BenchmarkRedis(b *testing.B) { + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + // Redis runs on port 6379 by default. + port := 6379 + ctx := context.Background() + + for _, operation := range operations { + b.Run(operation, func(b *testing.B) { + server := serverMachine.GetContainer(ctx, b) + defer server.CleanUp(ctx) + + // The redis docker container takes no arguments to run a redis server. + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + Ports: []int{port}, + }); err != nil { + b.Fatalf("failed to start redis server with: %v", err) + } + + if out, err := server.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatal("failed to get IP from server: %v", err) + } + + serverPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatal("failed to get IP from server: %v", err) + } + + if err = harness.WaitUntilServing(ctx, clientMachine, ip, serverPort); err != nil { + b.Fatalf("failed to start redis with: %v", err) + } + + // runs redis benchmark -t operation for 100K requests against server. + cmd := strings.Split( + fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", operation, ip, serverPort), " ") + + // There is no -t PING_BULK for redis-benchmark, so adjust the command in that case. + // Note that "ping" will run both PING_INLINE and PING_BULK. + if operation == "PING_BULK" { + cmd = strings.Split( + fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, serverPort), " ") + } + // Reset profiles and timer to begin the measurement. + server.RestartProfiles() + b.ResetTimer() + for i := 0; i < b.N; i++ { + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }, cmd...) + if err != nil { + b.Fatalf("redis-benchmark failed with: %v", err) + } + + // Stop time while we parse results. + b.StopTimer() + result, err := parseOperation(operation, out) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", out, err) + } + b.ReportMetric(result, operation) // operations per second + b.StartTimer() + } + }) + } +} + +// parseOperation grabs the metric operations per second from redis-benchmark output. +func parseOperation(operation, data string) (float64, error) { + re := regexp.MustCompile(fmt.Sprintf(`"%s( .*)?","(\d*\.\d*)"`, operation)) + match := re.FindStringSubmatch(data) + // If no match, simply don't add it to the result map. + if len(match) < 3 { + return 0.0, fmt.Errorf("could not find %s in %s", operation, data) + } + return strconv.ParseFloat(match[2], 64) +} + +// TestParser tests the parser on sample data. +func TestParser(t *testing.T) { + sampleData := ` + "PING_INLINE","48661.80" + "PING_BULK","50301.81" + "SET","48923.68" + "GET","49382.71" + "INCR","49975.02" + "LPUSH","49875.31" + "RPUSH","50276.52" + "LPOP","50327.12" + "RPOP","50556.12" + "SADD","49504.95" + "HSET","49504.95" + "SPOP","50025.02" + "LPUSH (needed to benchmark LRANGE)","48875.86" + "LRANGE_100 (first 100 elements)","33955.86" + "LRANGE_300 (first 300 elements)","16550.81" + "LRANGE_500 (first 450 elements)","13653.74" + "LRANGE_600 (first 600 elements)","11219.57" + "MSET (10 keys)","44682.75" + ` + wants := map[string]float64{ + "PING_INLINE": 48661.80, + "PING_BULK": 50301.81, + "SET": 48923.68, + "GET": 49382.71, + "INCR": 49975.02, + "LPUSH": 49875.31, + "RPUSH": 50276.52, + "LPOP": 50327.12, + "RPOP": 50556.12, + "SADD": 49504.95, + "HSET": 49504.95, + "SPOP": 50025.02, + "LRANGE_100": 33955.86, + "LRANGE_300": 16550.81, + "LRANGE_500": 13653.74, + "LRANGE_600": 11219.57, + "MSET": 44682.75, + } + for op, want := range wants { + if got, err := parseOperation(op, sampleData); err != nil { + t.Fatalf("failed to parse %s: %v", op, err) + } else if want != got { + t.Fatalf("wanted %f for op %s, got %f", want, op, got) + } + } +} diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go index fdcac1a7a..9b652fd43 100644 --- a/test/benchmarks/fs/bazel_test.go +++ b/test/benchmarks/fs/bazel_test.go @@ -15,6 +15,7 @@ package fs import ( "context" + "fmt" "strings" "testing" @@ -51,10 +52,10 @@ func BenchmarkABSL(b *testing.B) { workdir := "/abseil-cpp" - // Start a container. + // Start a container and sleep by an order of b.N. if err := container.Spawn(ctx, dockerutil.RunOpts{ Image: "benchmarks/absl", - }, "sleep", "1000"); err != nil { + }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil { b.Fatalf("run failed with: %v", err) } @@ -67,15 +68,21 @@ func BenchmarkABSL(b *testing.B) { workdir = "/tmp" + workdir } - // Drop Caches. - if bm.clearCache { - if out, err := machine.RunCommand("/bin/sh -c sync; echo 3 > /proc/sys/vm/drop_caches"); err != nil { - b.Fatalf("failed to drop caches: %v %s", err, out) - } - } - + // Restart profiles after the copy. + container.RestartProfiles() b.ResetTimer() + // Drop Caches and bazel clean should happen inside the loop as we may use + // time options with b.N. (e.g. Run for an hour.) for i := 0; i < b.N; i++ { + b.StopTimer() + // Drop Caches for clear cache runs. + if bm.clearCache { + if out, err := machine.RunCommand("/bin/sh", "-c", "sync && sysctl vm.drop_caches=3"); err != nil { + b.Skipf("failed to drop caches: %v %s. You probably need root.", err, out) + } + } + b.StartTimer() + got, err := container.Exec(ctx, dockerutil.ExecOpts{ WorkDir: workdir, }, "bazel", "build", "-c", "opt", "absl/base/...") @@ -88,6 +95,13 @@ func BenchmarkABSL(b *testing.B) { if !strings.Contains(got, want) { b.Fatalf("string %s not in: %s", want, got) } + // Clean bazel in case we use b.N. + _, err = container.Exec(ctx, dockerutil.ExecOpts{ + WorkDir: workdir, + }, "bazel", "clean") + if err != nil { + b.Fatalf("build failed with: %v", err) + } b.StartTimer() } }) diff --git a/test/benchmarks/harness/machine.go b/test/benchmarks/harness/machine.go index 93c0db9ce..88e5e841b 100644 --- a/test/benchmarks/harness/machine.go +++ b/test/benchmarks/harness/machine.go @@ -25,9 +25,14 @@ import ( // Machine describes a real machine for use in benchmarks. type Machine interface { - // GetContainer gets a container from the machine, + // GetContainer gets a container from the machine. The container uses the + // runtime under test and is profiled if requested by flags. GetContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container + // GetNativeContainer gets a native container from the machine. Native containers + // use runc by default and are not profiled. + GetNativeContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container + // RunCommand runs cmd on this machine. RunCommand(cmd string, args ...string) (string, error) @@ -47,6 +52,11 @@ func (l *localMachine) GetContainer(ctx context.Context, logger testutil.Logger) return dockerutil.MakeContainer(ctx, logger) } +// GetContainer implements Machine.GetContainer for localMachine. +func (l *localMachine) GetNativeContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container { + return dockerutil.MakeNativeContainer(ctx, logger) +} + // RunCommand implements Machine.RunCommand for localMachine. func (l *localMachine) RunCommand(cmd string, args ...string) (string, error) { c := exec.Command(cmd, args...) diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go index cc7de6426..bc551c582 100644 --- a/test/benchmarks/harness/util.go +++ b/test/benchmarks/harness/util.go @@ -27,12 +27,20 @@ import ( // IP:port. func WaitUntilServing(ctx context.Context, machine Machine, server net.IP, port int) error { var logger testutil.DefaultLogger = "netcat" - netcat := machine.GetContainer(ctx, logger) + netcat := machine.GetNativeContainer(ctx, logger) defer netcat.CleanUp(ctx) - cmd := fmt.Sprintf("while ! nc -zv %s %d; do true; done", server.String(), port) + cmd := fmt.Sprintf("while ! nc -zv %s %d; do true; done", server, port) _, err := netcat.Run(ctx, dockerutil.RunOpts{ Image: "packetdrill", }, "sh", "-c", cmd) return err } + +// DropCaches drops caches on the provided machine. Requires root. +func DropCaches(machine Machine) error { + if out, err := machine.RunCommand("/bin/sh", "-c", "sync | sysctl vm.drop_caches=3"); err != nil { + return fmt.Errorf("failed to drop caches: %v logs: %s", err, out) + } + return nil +} diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD new file mode 100644 index 000000000..6c41fc4f6 --- /dev/null +++ b/test/benchmarks/media/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "media", + testonly = 1, + srcs = ["media.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "media_test", + size = "large", + srcs = ["ffmpeg_test.go"], + library = ":media", + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go new file mode 100644 index 000000000..bfcfbab80 --- /dev/null +++ b/test/benchmarks/media/ffmpeg_test.go @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package media + +import ( + "context" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkFfmpeg runs ffmpeg in a container and records runtime. +// BenchmarkFfmpeg should run as root to drop caches. +func BenchmarkFfmpeg(b *testing.B) { + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + container := machine.GetContainer(ctx, b) + cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + b.StartTimer() + + if _, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/ffmpeg", + }, cmd...); err != nil { + b.Fatalf("failed to run container: %v", err) + } + } +} diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go new file mode 100644 index 000000000..c7b35b758 --- /dev/null +++ b/test/benchmarks/media/media.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package media holds benchmarks around media processing applications. +package media + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package media. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD index 16d267bc8..363041fb7 100644 --- a/test/benchmarks/network/BUILD +++ b/test/benchmarks/network/BUILD @@ -24,6 +24,7 @@ go_test( ], deps = [ "//pkg/test/dockerutil", + "//pkg/test/testutil", "//test/benchmarks/harness", ], ) diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go index f9afdf15f..fe23ca949 100644 --- a/test/benchmarks/network/httpd_test.go +++ b/test/benchmarks/network/httpd_test.go @@ -52,12 +52,12 @@ func BenchmarkHttpdConcurrency(b *testing.B) { defer serverMachine.CleanUp() // The test iterates over client concurrency, so set other parameters. - requests := 1000 + requests := 10000 concurrency := []int{1, 5, 10, 25} doc := docs["10Kb"] for _, c := range concurrency { - b.Run(fmt.Sprintf("%dConcurrency", c), func(b *testing.B) { + b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { runHttpd(b, clientMachine, serverMachine, doc, requests, c) }) } @@ -78,7 +78,7 @@ func BenchmarkHttpdDocSize(b *testing.B) { } defer serverMachine.CleanUp() - requests := 1000 + requests := 10000 concurrency := 1 for name, filename := range docs { @@ -129,7 +129,7 @@ func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, doc st harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) // Grab a client. - client := clientMachine.GetContainer(ctx, b) + client := clientMachine.GetNativeContainer(ctx, b) defer client.CleanUp(ctx) path := fmt.Sprintf("http://%s:%d/%s", ip, servingPort, doc) @@ -137,6 +137,7 @@ func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, doc st cmd = fmt.Sprintf("ab -n %d -c %d %s", requests, concurrency, path) b.ResetTimer() + server.RestartProfiles() for i := 0; i < b.N; i++ { out, err := client.Run(ctx, dockerutil.RunOpts{ Image: "benchmarks/ab", diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go index 664e0797e..a5e198e14 100644 --- a/test/benchmarks/network/iperf_test.go +++ b/test/benchmarks/network/iperf_test.go @@ -22,12 +22,13 @@ import ( "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/test/benchmarks/harness" ) func BenchmarkIperf(b *testing.B) { + const time = 10 // time in seconds to run the client. - // Get two machines clientMachine, err := h.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) @@ -39,30 +40,32 @@ func BenchmarkIperf(b *testing.B) { b.Fatalf("failed to get machine: %v", err) } defer serverMachine.CleanUp() - + ctx := context.Background() for _, bm := range []struct { - name string - clientRuntime string - serverRuntime string + name string + clientFunc func(context.Context, testutil.Logger) *dockerutil.Container + serverFunc func(context.Context, testutil.Logger) *dockerutil.Container }{ // We are either measuring the server or the client. The other should be // runc. e.g. Upload sees how fast the runtime under test uploads to a native // server. - {name: "Upload", clientRuntime: dockerutil.Runtime(), serverRuntime: "runc"}, - {name: "Download", clientRuntime: "runc", serverRuntime: dockerutil.Runtime()}, + { + name: "Upload", + clientFunc: clientMachine.GetContainer, + serverFunc: serverMachine.GetNativeContainer, + }, + { + name: "Download", + clientFunc: clientMachine.GetNativeContainer, + serverFunc: serverMachine.GetContainer, + }, } { b.Run(bm.name, func(b *testing.B) { - - // Get a container from the server and set its runtime. - ctx := context.Background() - server := serverMachine.GetContainer(ctx, b) + // Set up the containers. + server := bm.serverFunc(ctx, b) defer server.CleanUp(ctx) - server.Runtime = bm.serverRuntime - - // Get a container from the client and set its runtime. - client := clientMachine.GetContainer(ctx, b) + client := bm.clientFunc(ctx, b) defer client.CleanUp(ctx) - client.Runtime = bm.clientRuntime // iperf serves on port 5001 by default. port := 5001 @@ -91,11 +94,14 @@ func BenchmarkIperf(b *testing.B) { } // iperf report in Kb realtime - cmd := fmt.Sprintf("iperf -f K --realtime -c %s -p %d", ip.String(), servingPort) + cmd := fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", time, ip.String(), servingPort) // Run the client. b.ResetTimer() + // Restart the server profiles. If the server isn't being profiled + // this does nothing. + server.RestartProfiles() for i := 0; i < b.N; i++ { out, err := client.Run(ctx, dockerutil.RunOpts{ Image: "benchmarks/iperf", diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go index 1a0221893..74e1e6def 100644 --- a/test/packetimpact/runner/packetimpact_test.go +++ b/test/packetimpact/runner/packetimpact_test.go @@ -142,7 +142,7 @@ func TestOne(t *testing.T) { // Create the Docker container for the DUT. dut := dockerutil.MakeContainer(ctx, logger("dut")) if *dutPlatform == "linux" { - dut.Runtime = "" + dut = dockerutil.MakeNativeContainer(ctx, logger("dut")) } runOpts := dockerutil.RunOpts{ @@ -208,8 +208,7 @@ func TestOne(t *testing.T) { } // Create the Docker container for the testbench. - testbench := dockerutil.MakeContainer(ctx, logger("testbench")) - testbench.Runtime = "" // The testbench always runs on Linux. + testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) tbb := path.Base(*testbenchBinary) containerTestbenchBinary := "/packetimpact/" + tbb diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 87ce58c24..3af5f83fd 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -429,7 +429,6 @@ type Connection struct { layerStates []layerState injector Injector sniffer Sniffer - t *testing.T } // Returns the default incoming frame against which to match. If received is @@ -462,7 +461,9 @@ func (conn *Connection) match(override, received Layers) bool { } // Close frees associated resources held by the Connection. -func (conn *Connection) Close() { +func (conn *Connection) Close(t *testing.T) { + t.Helper() + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) for _, s := range conn.layerStates { if err := s.close(); err != nil { @@ -470,7 +471,7 @@ func (conn *Connection) Close() { } } if errs != nil { - conn.t.Fatalf("unable to close %+v: %s", conn, errs) + t.Fatalf("unable to close %+v: %s", conn, errs) } } @@ -482,7 +483,9 @@ func (conn *Connection) Close() { // overriden first. As an example, valid values of overrideLayers for a TCP- // over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and // [Ethernet, IPv4, TCP]. -func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...Layer) Layers { +func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers { + t.Helper() + var layersToSend Layers for i, s := range conn.layerStates { layer := s.outgoing() @@ -491,7 +494,7 @@ func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...L // end. if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 { if err := layer.merge(overrideLayers[j]); err != nil { - conn.t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) + t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) } } layersToSend = append(layersToSend, layer) @@ -505,21 +508,25 @@ func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...L // This method is useful for sending out-of-band control messages such as // ICMP packets, where it would not make sense to update the transport layer's // state using the ICMP header. -func (conn *Connection) SendFrameStateless(frame Layers) { +func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) { + t.Helper() + outBytes, err := frame.ToBytes() if err != nil { - conn.t.Fatalf("can't build outgoing packet: %s", err) + t.Fatalf("can't build outgoing packet: %s", err) } - conn.injector.Send(outBytes) + conn.injector.Send(t, outBytes) } // SendFrame sends a frame on the wire and updates the state of all layers. -func (conn *Connection) SendFrame(frame Layers) { +func (conn *Connection) SendFrame(t *testing.T, frame Layers) { + t.Helper() + outBytes, err := frame.ToBytes() if err != nil { - conn.t.Fatalf("can't build outgoing packet: %s", err) + t.Fatalf("can't build outgoing packet: %s", err) } - conn.injector.Send(outBytes) + conn.injector.Send(t, outBytes) // frame might have nil values where the caller wanted to use default values. // sentFrame will have no nil values in it because it comes from parsing the @@ -528,7 +535,7 @@ func (conn *Connection) SendFrame(frame Layers) { // Update the state of each layer based on what was sent. for i, s := range conn.layerStates { if err := s.sent(sentFrame[i]); err != nil { - conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) + t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) } } } @@ -538,18 +545,22 @@ func (conn *Connection) SendFrame(frame Layers) { // // Types defined with Connection as the underlying type should expose // type-safe versions of this method. -func (conn *Connection) send(overrideLayers Layers, additionalLayers ...Layer) { - conn.SendFrame(conn.CreateFrame(overrideLayers, additionalLayers...)) +func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + t.Helper() + + conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...)) } // recvFrame gets the next successfully parsed frame (of type Layers) within the // timeout provided. If no parsable frame arrives before the timeout, it returns // nil. -func (conn *Connection) recvFrame(timeout time.Duration) Layers { +func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers { + t.Helper() + if timeout <= 0 { return nil } - b := conn.sniffer.Recv(timeout) + b := conn.sniffer.Recv(t, timeout) if b == nil { return nil } @@ -569,32 +580,36 @@ func (e *layersError) Error() string { // Expect expects a frame with the final layerStates layer matching the // provided Layer within the timeout specified. If it doesn't arrive in time, // an error is returned. -func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) { +func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) { + t.Helper() + // Make a frame that will ignore all but the final layer. layers := make([]Layer, len(conn.layerStates)) layers[len(layers)-1] = layer - gotFrame, err := conn.ExpectFrame(layers, timeout) + gotFrame, err := conn.ExpectFrame(t, layers, timeout) if err != nil { return nil, err } if len(conn.layerStates)-1 < len(gotFrame) { return gotFrame[len(conn.layerStates)-1], nil } - conn.t.Fatal("the received frame should be at least as long as the expected layers") + t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame) panic("unreachable") } // ExpectFrame expects a frame that matches the provided Layers within the // timeout specified. If one arrives in time, the Layers is returned without an // error. If it doesn't arrive in time, it returns nil and error is non-nil. -func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) { +func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) { + t.Helper() + deadline := time.Now().Add(timeout) var errs error for { var gotLayers Layers if timeout = time.Until(deadline); timeout > 0 { - gotLayers = conn.recvFrame(timeout) + gotLayers = conn.recvFrame(t, timeout) } if gotLayers == nil { if errs == nil { @@ -605,7 +620,7 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer if conn.match(layers, gotLayers) { for i, s := range conn.layerStates { if err := s.received(gotLayers[i]); err != nil { - conn.t.Fatal(err) + t.Fatalf("failed to update test connection's layer states based on received frame: %s", err) } } return gotLayers, nil @@ -616,8 +631,10 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *Connection) Drain() { - conn.sniffer.Drain() +func (conn *Connection) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) } // TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. @@ -625,6 +642,8 @@ type TCPIPv4 Connection // NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make etherState: %s", err) @@ -650,57 +669,58 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { layerStates: []layerState{etherState, ipv4State, tcpState}, injector: injector, sniffer: sniffer, - t: t, } } // Connect performs a TCP 3-way handshake. The input Connection should have a // final TCP Layer. -func (conn *TCPIPv4) Connect() { - conn.t.Helper() +func (conn *TCPIPv4) Connect(t *testing.T) { + t.Helper() // Send the SYN. - conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. - synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { - conn.t.Fatalf("didn't get synack during handshake: %s", err) + t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. - conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) } // ConnectWithOptions performs a TCP 3-way handshake with given TCP options. // The input Connection should have a final TCP Layer. -func (conn *TCPIPv4) ConnectWithOptions(options []byte) { - conn.t.Helper() +func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) { + t.Helper() // Send the SYN. - conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) // Wait for the SYN-ACK. - synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { - conn.t.Fatalf("didn't get synack during handshake: %s", err) + t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. - conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) } // ExpectData is a convenient method that expects a Layer and the Layer after // it. If it doens't arrive in time, it returns nil. -func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { +func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + expected := make([]Layer, len(conn.layerStates)) expected[len(expected)-1] = tcp if payload != nil { expected = append(expected, payload) } - return (*Connection)(conn).ExpectFrame(expected, timeout) + return (*Connection)(conn).ExpectFrame(t, expected, timeout) } // ExpectNextData attempts to receive the next incoming segment for the @@ -709,9 +729,11 @@ func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duratio // It differs from ExpectData() in that here we are only interested in the next // received segment, while ExpectData() can receive multiple segments for the // connection until there is a match with given layers or a timeout. -func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { +func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + // Receive the first incoming TCP segment for this connection. - got, err := conn.ExpectData(&TCP{}, nil, timeout) + got, err := conn.ExpectData(t, &TCP{}, nil, timeout) if err != nil { return nil, err } @@ -720,7 +742,7 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur expected[len(expected)-1] = tcp if payload != nil { expected = append(expected, payload) - tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length())) + tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length())) } if !(*Connection)(conn).match(expected, got) { return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) @@ -730,71 +752,91 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur // Send a packet with reasonable defaults. Potentially override the TCP layer in // the connection with the provided layer and add additionLayers. -func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&tcp}, additionalLayers...) +func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&tcp}, additionalLayers...) } // Close frees associated resources held by the TCPIPv4 connection. -func (conn *TCPIPv4) Close() { - (*Connection)(conn).Close() +func (conn *TCPIPv4) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // Expect expects a frame with the TCP layer matching the provided TCP within // the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { - layer, err := (*Connection)(conn).Expect(&tcp, timeout) +func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &tcp, timeout) if layer == nil { return nil, err } gotTCP, ok := layer.(*TCP) if !ok { - conn.t.Fatalf("expected %s to be TCP", layer) + t.Fatalf("expected %s to be TCP", layer) } return gotTCP, err } -func (conn *TCPIPv4) tcpState() *tcpState { +func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState { + t.Helper() + state, ok := conn.layerStates[2].(*tcpState) if !ok { - conn.t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) + t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) } return state } -func (conn *TCPIPv4) ipv4State() *ipv4State { +func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State { + t.Helper() + state, ok := conn.layerStates[1].(*ipv4State) if !ok { - conn.t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) } return state } // RemoteSeqNum returns the next expected sequence number from the DUT. -func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { - return conn.tcpState().remoteSeqNum +func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value { + t.Helper() + + return conn.tcpState(t).remoteSeqNum } // LocalSeqNum returns the next sequence number to send from the testbench. -func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { - return conn.tcpState().localSeqNum +func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value { + t.Helper() + + return conn.tcpState(t).localSeqNum } // SynAck returns the SynAck that was part of the handshake. -func (conn *TCPIPv4) SynAck() *TCP { - return conn.tcpState().synAck +func (conn *TCPIPv4) SynAck(t *testing.T) *TCP { + t.Helper() + + return conn.tcpState(t).synAck } // LocalAddr gets the local socket address of this connection. -func (conn *TCPIPv4) LocalAddr() *unix.SockaddrInet4 { - sa := &unix.SockaddrInet4{Port: int(*conn.tcpState().out.SrcPort)} - copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr) +func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { + t.Helper() + + sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) return sa } // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *TCPIPv4) Drain() { - conn.sniffer.Drain() +func (conn *TCPIPv4) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) } // IPv6Conn maintains the state for all the layers in a IPv6 connection. @@ -802,6 +844,8 @@ type IPv6Conn Connection // NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults. func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make EtherState: %s", err) @@ -824,25 +868,30 @@ func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { layerStates: []layerState{etherState, ipv6State}, injector: injector, sniffer: sniffer, - t: t, } } // Send sends a frame with ipv6 overriding the IPv6 layer defaults and // additionalLayers added after it. -func (conn *IPv6Conn) Send(ipv6 IPv6, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&ipv6}, additionalLayers...) +func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ipv6}, additionalLayers...) } // Close to clean up any resources held. -func (conn *IPv6Conn) Close() { - (*Connection)(conn).Close() +func (conn *IPv6Conn) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // ExpectFrame expects a frame that matches the provided Layers within the // timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) { - return (*Connection)(conn).ExpectFrame(frame, timeout) +func (conn *IPv6Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(conn).ExpectFrame(t, frame, timeout) } // UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. @@ -850,6 +899,8 @@ type UDPIPv4 Connection // NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make etherState: %s", err) @@ -875,81 +926,96 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { layerStates: []layerState{etherState, ipv4State, udpState}, injector: injector, sniffer: sniffer, - t: t, } } -func (conn *UDPIPv4) udpState() *udpState { +func (conn *UDPIPv4) udpState(t *testing.T) *udpState { + t.Helper() + state, ok := conn.layerStates[2].(*udpState) if !ok { - conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) } return state } -func (conn *UDPIPv4) ipv4State() *ipv4State { +func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State { + t.Helper() + state, ok := conn.layerStates[1].(*ipv4State) if !ok { - conn.t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) } return state } // LocalAddr gets the local socket address of this connection. -func (conn *UDPIPv4) LocalAddr() *unix.SockaddrInet4 { - sa := &unix.SockaddrInet4{Port: int(*conn.udpState().out.SrcPort)} - copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr) +func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { + t.Helper() + + sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) return sa } // Send sends a packet with reasonable defaults, potentially overriding the UDP // layer and adding additionLayers. -func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&udp}, additionalLayers...) +func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...) } // SendIP sends a packet with reasonable defaults, potentially overriding the // UDP and IPv4 headers and adding additionLayers. -func (conn *UDPIPv4) SendIP(ip IPv4, udp UDP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...) +func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) } // Expect expects a frame with the UDP layer matching the provided UDP within // the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) { - conn.t.Helper() - layer, err := (*Connection)(conn).Expect(&udp, timeout) +func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &udp, timeout) if err != nil { return nil, err } gotUDP, ok := layer.(*UDP) if !ok { - conn.t.Fatalf("expected %s to be UDP", layer) + t.Fatalf("expected %s to be UDP", layer) } return gotUDP, nil } // ExpectData is a convenient method that expects a Layer and the Layer after // it. If it doens't arrive in time, it returns nil. -func (conn *UDPIPv4) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) { - conn.t.Helper() +func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + t.Helper() + expected := make([]Layer, len(conn.layerStates)) expected[len(expected)-1] = &udp if payload.length() != 0 { expected = append(expected, &payload) } - return (*Connection)(conn).ExpectFrame(expected, timeout) + return (*Connection)(conn).ExpectFrame(t, expected, timeout) } // Close frees associated resources held by the UDPIPv4 connection. -func (conn *UDPIPv4) Close() { - (*Connection)(conn).Close() +func (conn *UDPIPv4) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *UDPIPv4) Drain() { - conn.sniffer.Drain() +func (conn *UDPIPv4) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) } // UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection. @@ -957,6 +1023,8 @@ type UDPIPv6 Connection // NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults. func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make etherState: %s", err) @@ -981,86 +1049,101 @@ func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 { layerStates: []layerState{etherState, ipv6State, udpState}, injector: injector, sniffer: sniffer, - t: t, } } -func (conn *UDPIPv6) udpState() *udpState { +func (conn *UDPIPv6) udpState(t *testing.T) *udpState { + t.Helper() + state, ok := conn.layerStates[2].(*udpState) if !ok { - conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) } return state } -func (conn *UDPIPv6) ipv6State() *ipv6State { +func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State { + t.Helper() + state, ok := conn.layerStates[1].(*ipv6State) if !ok { - conn.t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1]) + t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1]) } return state } // LocalAddr gets the local socket address of this connection. -func (conn *UDPIPv6) LocalAddr() *unix.SockaddrInet6 { +func (conn *UDPIPv6) LocalAddr(t *testing.T) *unix.SockaddrInet6 { + t.Helper() + sa := &unix.SockaddrInet6{ - Port: int(*conn.udpState().out.SrcPort), + Port: int(*conn.udpState(t).out.SrcPort), // Local address is in perspective to the remote host, so it's scoped to the // ID of the remote interface. ZoneId: uint32(RemoteInterfaceID), } - copy(sa.Addr[:], *conn.ipv6State().out.SrcAddr) + copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr) return sa } // Send sends a packet with reasonable defaults, potentially overriding the UDP // layer and adding additionLayers. -func (conn *UDPIPv6) Send(udp UDP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&udp}, additionalLayers...) +func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...) } // SendIPv6 sends a packet with reasonable defaults, potentially overriding the // UDP and IPv6 headers and adding additionLayers. -func (conn *UDPIPv6) SendIPv6(ip IPv6, udp UDP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...) +func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) } // Expect expects a frame with the UDP layer matching the provided UDP within // the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *UDPIPv6) Expect(udp UDP, timeout time.Duration) (*UDP, error) { - conn.t.Helper() - layer, err := (*Connection)(conn).Expect(&udp, timeout) +func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &udp, timeout) if err != nil { return nil, err } gotUDP, ok := layer.(*UDP) if !ok { - conn.t.Fatalf("expected %s to be UDP", layer) + t.Fatalf("expected %s to be UDP", layer) } return gotUDP, nil } // ExpectData is a convenient method that expects a Layer and the Layer after // it. If it doens't arrive in time, it returns nil. -func (conn *UDPIPv6) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) { - conn.t.Helper() +func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + t.Helper() + expected := make([]Layer, len(conn.layerStates)) expected[len(expected)-1] = &udp if payload.length() != 0 { expected = append(expected, &payload) } - return (*Connection)(conn).ExpectFrame(expected, timeout) + return (*Connection)(conn).ExpectFrame(t, expected, timeout) } // Close frees associated resources held by the UDPIPv6 connection. -func (conn *UDPIPv6) Close() { - (*Connection)(conn).Close() +func (conn *UDPIPv6) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *UDPIPv6) Drain() { - conn.sniffer.Drain() +func (conn *UDPIPv6) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) } // TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection. @@ -1093,7 +1176,6 @@ func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 { layerStates: []layerState{etherState, ipv6State, tcpState}, injector: injector, sniffer: sniffer, - t: t, } } @@ -1104,16 +1186,20 @@ func (conn *TCPIPv6) SrcPort() uint16 { // ExpectData is a convenient method that expects a Layer and the Layer after // it. If it doens't arrive in time, it returns nil. -func (conn *TCPIPv6) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { +func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + expected := make([]Layer, len(conn.layerStates)) expected[len(expected)-1] = tcp if payload != nil { expected = append(expected, payload) } - return (*Connection)(conn).ExpectFrame(expected, timeout) + return (*Connection)(conn).ExpectFrame(t, expected, timeout) } // Close frees associated resources held by the TCPIPv6 connection. -func (conn *TCPIPv6) Close() { - (*Connection)(conn).Close() +func (conn *TCPIPv6) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 51be13759..73c532e75 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -31,13 +31,14 @@ import ( // DUT communicates with the DUT to force it to make POSIX calls. type DUT struct { - t *testing.T conn *grpc.ClientConn posixServer POSIXClient } // NewDUT creates a new connection with the DUT over gRPC. func NewDUT(t *testing.T) DUT { + t.Helper() + flag.Parse() if err := genPseudoFlags(); err != nil { t.Fatal("generating psuedo flags:", err) @@ -50,7 +51,6 @@ func NewDUT(t *testing.T) DUT { } posixServer := NewPOSIXClient(conn) return DUT{ - t: t, conn: conn, posixServer: posixServer, } @@ -61,8 +61,9 @@ func (dut *DUT) TearDown() { dut.conn.Close() } -func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { - dut.t.Helper() +func (dut *DUT) sockaddrToProto(t *testing.T, sa unix.Sockaddr) *pb.Sockaddr { + t.Helper() + switch s := sa.(type) { case *unix.SockaddrInet4: return &pb.Sockaddr{ @@ -87,12 +88,13 @@ func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { }, } } - dut.t.Fatalf("can't parse Sockaddr struct: %+v", sa) + t.Fatalf("can't parse Sockaddr struct: %+v", sa) return nil } -func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { - dut.t.Helper() +func (dut *DUT) protoToSockaddr(t *testing.T, sa *pb.Sockaddr) unix.Sockaddr { + t.Helper() + switch s := sa.Sockaddr.(type) { case *pb.Sockaddr_In: ret := unix.SockaddrInet4{ @@ -108,31 +110,32 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { copy(ret.Addr[:], s.In6.GetAddr()) return &ret } - dut.t.Fatalf("can't parse Sockaddr proto: %+v", sa) + t.Fatalf("can't parse Sockaddr proto: %#v", sa) return nil } // CreateBoundSocket makes a new socket on the DUT, with type typ and protocol // proto, and bound to the IP address addr. Returns the new file descriptor and // the port that was selected on the DUT. -func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) { - dut.t.Helper() +func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (int32, uint16) { + t.Helper() + var fd int32 if addr.To4() != nil { - fd = dut.Socket(unix.AF_INET, typ, proto) + fd = dut.Socket(t, unix.AF_INET, typ, proto) sa := unix.SockaddrInet4{} copy(sa.Addr[:], addr.To4()) - dut.Bind(fd, &sa) + dut.Bind(t, fd, &sa) } else if addr.To16() != nil { - fd = dut.Socket(unix.AF_INET6, typ, proto) + fd = dut.Socket(t, unix.AF_INET6, typ, proto) sa := unix.SockaddrInet6{} copy(sa.Addr[:], addr.To16()) sa.ZoneId = uint32(RemoteInterfaceID) - dut.Bind(fd, &sa) + dut.Bind(t, fd, &sa) } else { - dut.t.Fatalf("unknown ip addr type for remoteIP") + t.Fatalf("invalid IP address: %s", addr) } - sa := dut.GetSockName(fd) + sa := dut.GetSockName(t, fd) var port int switch s := sa.(type) { case *unix.SockaddrInet4: @@ -140,15 +143,17 @@ func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) case *unix.SockaddrInet6: port = s.Port default: - dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa) + t.Fatalf("unknown sockaddr type from getsockname: %T", sa) } return fd, uint16(port) } // CreateListener makes a new TCP connection. If it fails, the test ends. -func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { - fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4)) - dut.Listen(fd, backlog) +func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) { + t.Helper() + + fd, remotePort := dut.CreateBoundSocket(t, typ, proto, net.ParseIP(RemoteIPv4)) + dut.Listen(t, fd, backlog) return fd, remotePort } @@ -158,53 +163,57 @@ func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { // Accept calls accept on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // AcceptWithErrno. -func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - fd, sa, err := dut.AcceptWithErrno(ctx, sockfd) + fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd) if fd < 0 { - dut.t.Fatalf("failed to accept: %s", err) + t.Fatalf("failed to accept: %s", err) } return fd, sa } // AcceptWithErrno calls accept on the DUT. -func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { - dut.t.Helper() +func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { + t.Helper() + req := pb.AcceptRequest{ Sockfd: sockfd, } resp, err := dut.posixServer.Accept(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Accept: %s", err) + t.Fatalf("failed to call Accept: %s", err) } - return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) } // Bind calls bind on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is // needed, use BindWithErrno. -func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.BindWithErrno(ctx, fd, sa) + ret, err := dut.BindWithErrno(ctx, t, fd, sa) if ret != 0 { - dut.t.Fatalf("failed to bind socket: %s", err) + t.Fatalf("failed to bind socket: %s", err) } } // BindWithErrno calls bind on the DUT. -func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.BindRequest{ Sockfd: fd, - Addr: dut.sockaddrToProto(sa), + Addr: dut.sockaddrToProto(t, sa), } resp, err := dut.posixServer.Bind(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Bind: %s", err) + t.Fatalf("failed to call Bind: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -212,25 +221,27 @@ func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) ( // Close calls close on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // CloseWithErrno. -func (dut *DUT) Close(fd int32) { - dut.t.Helper() +func (dut *DUT) Close(t *testing.T, fd int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.CloseWithErrno(ctx, fd) + ret, err := dut.CloseWithErrno(ctx, t, fd) if ret != 0 { - dut.t.Fatalf("failed to close: %s", err) + t.Fatalf("failed to close: %s", err) } } // CloseWithErrno calls close on the DUT. -func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int32, error) { + t.Helper() + req := pb.CloseRequest{ Fd: fd, } resp, err := dut.posixServer.Close(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Close: %s", err) + t.Fatalf("failed to call Close: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -238,28 +249,30 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { // Connect calls connect on the DUT and causes a fatal test failure if it // doesn't succeed. If more control over the timeout or error handling is // needed, use ConnectWithErrno. -func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.ConnectWithErrno(ctx, fd, sa) + ret, err := dut.ConnectWithErrno(ctx, t, fd, sa) // Ignore 'operation in progress' error that can be returned when the socket // is non-blocking. if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 { - dut.t.Fatalf("failed to connect socket: %s", err) + t.Fatalf("failed to connect socket: %s", err) } } // ConnectWithErrno calls bind on the DUT. -func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.ConnectRequest{ Sockfd: fd, - Addr: dut.sockaddrToProto(sa), + Addr: dut.sockaddrToProto(t, sa), } resp, err := dut.posixServer.Connect(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Connect: %s", err) + t.Fatalf("failed to call Connect: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -267,20 +280,22 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr // Fcntl calls fcntl on the DUT and causes a fatal test failure if it // doesn't succeed. If more control over the timeout or error handling is // needed, use FcntlWithErrno. -func (dut *DUT) Fcntl(fd, cmd, arg int32) int32 { - dut.t.Helper() +func (dut *DUT) Fcntl(t *testing.T, fd, cmd, arg int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.FcntlWithErrno(ctx, fd, cmd, arg) + ret, err := dut.FcntlWithErrno(ctx, t, fd, cmd, arg) if ret == -1 { - dut.t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) + t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) } return ret } // FcntlWithErrno calls fcntl on the DUT. -func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) FcntlWithErrno(ctx context.Context, t *testing.T, fd, cmd, arg int32) (int32, error) { + t.Helper() + req := pb.FcntlRequest{ Fd: fd, Cmd: cmd, @@ -288,7 +303,7 @@ func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, } resp, err := dut.posixServer.Fcntl(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Fcntl: %s", err) + t.Fatalf("failed to call Fcntl: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -296,32 +311,35 @@ func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, // GetSockName calls getsockname on the DUT and causes a fatal test failure if // it doesn't succeed. If more control over the timeout or error handling is // needed, use GetSockNameWithErrno. -func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr { - dut.t.Helper() +func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd) + ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd) if ret != 0 { - dut.t.Fatalf("failed to getsockname: %s", err) + t.Fatalf("failed to getsockname: %s", err) } return sa } // GetSockNameWithErrno calls getsockname on the DUT. -func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { - dut.t.Helper() +func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { + t.Helper() + req := pb.GetSockNameRequest{ Sockfd: sockfd, } resp, err := dut.posixServer.GetSockName(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Bind: %s", err) + t.Fatalf("failed to call Bind: %s", err) } - return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) } -func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { - dut.t.Helper() +func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { + t.Helper() + req := pb.GetSockOptRequest{ Sockfd: sockfd, Level: level, @@ -331,11 +349,11 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i } resp, err := dut.posixServer.GetSockOpt(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call GetSockOpt: %s", err) + t.Fatalf("failed to call GetSockOpt: %s", err) } optval := resp.GetOptval() if optval == nil { - dut.t.Fatalf("GetSockOpt response does not contain a value") + t.Fatalf("GetSockOpt response does not contain a value") } return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_()) } @@ -345,13 +363,14 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i // needed, use GetSockOptWithErrno. Because endianess and the width of values // might differ between the testbench and DUT architectures, prefer to use a // more specific GetSockOptXxx function. -func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte { - dut.t.Helper() +func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen) + ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen) if ret != 0 { - dut.t.Fatalf("failed to GetSockOpt: %s", err) + t.Fatalf("failed to GetSockOpt: %s", err) } return optval } @@ -359,12 +378,13 @@ func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte { // GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the // width of values might differ between the testbench and DUT architectures, // prefer to use a more specific GetSockOptXxxWithErrno function. -func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) +func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32) (int32, []byte, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want bytes", optval) + t.Fatalf("GetSockOpt got value type: %T, want bytes", optval.Val) } return ret, bytesval.Bytesval, errno } @@ -372,24 +392,26 @@ func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, // GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the int optval or error handling // is needed, use GetSockOptIntWithErrno. -func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 { - dut.t.Helper() +func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname) + ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname) if ret != 0 { - dut.t.Fatalf("failed to GetSockOptInt: %s", err) + t.Fatalf("failed to GetSockOptInt: %s", err) } return intval } // GetSockOptIntWithErrno calls getsockopt with an integer optval. -func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) +func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, int32, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) intval, ok := optval.Val.(*pb.SockOptVal_Intval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want int", optval) + t.Fatalf("GetSockOpt got value type: %T, want int", optval.Val) } return ret, intval.Intval, errno } @@ -397,24 +419,26 @@ func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna // GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the timeout or error handling is // needed, use GetSockOptTimevalWithErrno. -func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval { - dut.t.Helper() +func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname) + ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname) if ret != 0 { - dut.t.Fatalf("failed to GetSockOptTimeval: %s", err) + t.Fatalf("failed to GetSockOptTimeval: %s", err) } return timeval } // GetSockOptTimevalWithErrno calls getsockopt and returns a timeval. -func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) +func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, unix.Timeval, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) tv, ok := optval.Val.(*pb.SockOptVal_Timeval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want timeval", optval) + t.Fatalf("GetSockOpt got value type: %T, want timeval", optval.Val) } timeval := unix.Timeval{ Sec: tv.Timeval.Seconds, @@ -426,26 +450,28 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, o // Listen calls listen on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // ListenWithErrno. -func (dut *DUT) Listen(sockfd, backlog int32) { - dut.t.Helper() +func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.ListenWithErrno(ctx, sockfd, backlog) + ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog) if ret != 0 { - dut.t.Fatalf("failed to listen: %s", err) + t.Fatalf("failed to listen: %s", err) } } // ListenWithErrno calls listen on the DUT. -func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backlog int32) (int32, error) { + t.Helper() + req := pb.ListenRequest{ Sockfd: sockfd, Backlog: backlog, } resp, err := dut.posixServer.Listen(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Listen: %s", err) + t.Fatalf("failed to call Listen: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -453,20 +479,22 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int // Send calls send on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // SendWithErrno. -func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 { - dut.t.Helper() +func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags) + ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags) if ret == -1 { - dut.t.Fatalf("failed to send: %s", err) + t.Fatalf("failed to send: %s", err) } return ret } // SendWithErrno calls send on the DUT. -func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32) (int32, error) { + t.Helper() + req := pb.SendRequest{ Sockfd: sockfd, Buf: buf, @@ -474,7 +502,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla } resp, err := dut.posixServer.Send(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Send: %s", err) + t.Fatalf("failed to call Send: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -482,48 +510,52 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla // SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // SendToWithErrno. -func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { - dut.t.Helper() +func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr) + ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr) if ret == -1 { - dut.t.Fatalf("failed to sendto: %s", err) + t.Fatalf("failed to sendto: %s", err) } return ret } // SendToWithErrno calls sendto on the DUT. -func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.SendToRequest{ Sockfd: sockfd, Buf: buf, Flags: flags, - DestAddr: dut.sockaddrToProto(destAddr), + DestAddr: dut.sockaddrToProto(t, destAddr), } resp, err := dut.posixServer.SendTo(ctx, &req) if err != nil { - dut.t.Fatalf("faled to call SendTo: %s", err) + t.Fatalf("faled to call SendTo: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } // SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking // is true, otherwise it will clear the flag. -func (dut *DUT) SetNonBlocking(fd int32, nonblocking bool) { - dut.t.Helper() - flags := dut.Fcntl(fd, unix.F_GETFL, 0) +func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) { + t.Helper() + + flags := dut.Fcntl(t, fd, unix.F_GETFL, 0) if nonblocking { flags |= unix.O_NONBLOCK } else { flags &= ^unix.O_NONBLOCK } - dut.Fcntl(fd, unix.F_SETFL, flags) + dut.Fcntl(t, fd, unix.F_SETFL, flags) } -func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { - dut.t.Helper() +func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { + t.Helper() + req := pb.SetSockOptRequest{ Sockfd: sockfd, Level: level, @@ -532,7 +564,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op } resp, err := dut.posixServer.SetSockOpt(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call SetSockOpt: %s", err) + t.Fatalf("failed to call SetSockOpt: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -542,81 +574,89 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op // needed, use SetSockOptWithErrno. Because endianess and the width of values // might differ between the testbench and DUT architectures, prefer to use a // more specific SetSockOptXxx function. -func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) { - dut.t.Helper() +func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval) + ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval) if ret != 0 { - dut.t.Fatalf("failed to SetSockOpt: %s", err) + t.Fatalf("failed to SetSockOpt: %s", err) } } // SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the // width of values might differ between the testbench and DUT architectures, // prefer to use a more specific SetSockOptXxxWithErrno function. -func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) { - dut.t.Helper() - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) +func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval []byte) (int32, error) { + t.Helper() + + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) } // SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the int optval or error handling // is needed, use SetSockOptIntWithErrno. -func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) { - dut.t.Helper() +func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval) + ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval) if ret != 0 { - dut.t.Fatalf("failed to SetSockOptInt: %s", err) + t.Fatalf("failed to SetSockOptInt: %s", err) } } // SetSockOptIntWithErrno calls setsockopt with an integer optval. -func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) { - dut.t.Helper() - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) +func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optval int32) (int32, error) { + t.Helper() + + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) } // SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the timeout or error handling is // needed, use SetSockOptTimevalWithErrno. -func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) { - dut.t.Helper() +func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv) + ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv) if ret != 0 { - dut.t.Fatalf("failed to SetSockOptTimeval: %s", err) + t.Fatalf("failed to SetSockOptTimeval: %s", err) } } // SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to // bytes. -func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { - dut.t.Helper() +func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { + t.Helper() + timeval := pb.Timeval{ Seconds: int64(tv.Sec), Microseconds: int64(tv.Usec), } - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) } // Socket calls socket on the DUT and returns the file descriptor. If socket // fails on the DUT, the test ends. -func (dut *DUT) Socket(domain, typ, proto int32) int32 { - dut.t.Helper() - fd, err := dut.SocketWithErrno(domain, typ, proto) +func (dut *DUT) Socket(t *testing.T, domain, typ, proto int32) int32 { + t.Helper() + + fd, err := dut.SocketWithErrno(t, domain, typ, proto) if fd < 0 { - dut.t.Fatalf("failed to create socket: %s", err) + t.Fatalf("failed to create socket: %s", err) } return fd } // SocketWithErrno calls socket on the DUT and returns the fd and errno. -func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, error) { + t.Helper() + req := pb.SocketRequest{ Domain: domain, Type: typ, @@ -625,7 +665,7 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { ctx := context.Background() resp, err := dut.posixServer.Socket(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Socket: %s", err) + t.Fatalf("failed to call Socket: %s", err) } return resp.GetFd(), syscall.Errno(resp.GetErrno_()) } @@ -633,20 +673,22 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { // Recv calls recv on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // RecvWithErrno. -func (dut *DUT) Recv(sockfd, len, flags int32) []byte { - dut.t.Helper() +func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags) + ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags) if ret == -1 { - dut.t.Fatalf("failed to recv: %s", err) + t.Fatalf("failed to recv: %s", err) } return buf } // RecvWithErrno calls recv on the DUT. -func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) { - dut.t.Helper() +func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, flags int32) (int32, []byte, error) { + t.Helper() + req := pb.RecvRequest{ Sockfd: sockfd, Len: len, @@ -654,7 +696,7 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (in } resp, err := dut.posixServer.Recv(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Recv: %s", err) + t.Fatalf("failed to call Recv: %s", err) } return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) } diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 278229b7e..57e822725 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -28,7 +28,6 @@ import ( // Sniffer can sniff raw packets on the wire. type Sniffer struct { - t *testing.T fd int } @@ -40,6 +39,8 @@ func htons(x uint16) uint16 { // NewSniffer creates a Sniffer connected to *device. func NewSniffer(t *testing.T) (Sniffer, error) { + t.Helper() + snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) if err != nil { return Sniffer{}, err @@ -51,7 +52,6 @@ func NewSniffer(t *testing.T) (Sniffer, error) { t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err) } return Sniffer{ - t: t, fd: snifferFd, }, nil } @@ -61,7 +61,9 @@ func NewSniffer(t *testing.T) (Sniffer, error) { const maxReadSize int = 65536 // Recv tries to read one frame until the timeout is up. -func (s *Sniffer) Recv(timeout time.Duration) []byte { +func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte { + t.Helper() + deadline := time.Now().Add(timeout) for { timeout = deadline.Sub(time.Now()) @@ -75,7 +77,7 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { } if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { - s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) + t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) } buf := make([]byte, maxReadSize) @@ -85,10 +87,10 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { continue } if err != nil { - s.t.Fatalf("can't read: %s", err) + t.Fatalf("can't read: %s", err) } if nread > maxReadSize { - s.t.Fatalf("received a truncated frame of %d bytes", nread) + t.Fatalf("received a truncated frame of %d bytes, want at most %d bytes", nread, maxReadSize) } return buf[:nread] } @@ -96,14 +98,16 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { // Drain drains the Sniffer's socket receive buffer by receiving until there's // nothing else to receive. -func (s *Sniffer) Drain() { - s.t.Helper() +func (s *Sniffer) Drain(t *testing.T) { + t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) if err != nil { - s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + t.Fatalf("failed to get sniffer socket fd flags: %s", err) } - if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { - s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + nonBlockingFlags := flags | unix.O_NONBLOCK + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, nonBlockingFlags); err != nil { + t.Fatalf("failed to make sniffer socket non-blocking with flags %b: %s", nonBlockingFlags, err) } for { buf := make([]byte, maxReadSize) @@ -113,7 +117,7 @@ func (s *Sniffer) Drain() { } } if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { - s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + t.Fatalf("failed to restore sniffer socket fd flags to %b: %s", flags, err) } } @@ -128,12 +132,13 @@ func (s *Sniffer) close() error { // Injector can inject raw frames. type Injector struct { - t *testing.T fd int } // NewInjector creates a new injector on *device. func NewInjector(t *testing.T) (Injector, error) { + t.Helper() + ifInfo, err := net.InterfaceByName(Device) if err != nil { return Injector{}, err @@ -156,15 +161,20 @@ func NewInjector(t *testing.T) (Injector, error) { return Injector{}, err } return Injector{ - t: t, fd: injectFd, }, nil } // Send a raw frame. -func (i *Injector) Send(b []byte) { - if _, err := unix.Write(i.fd, b); err != nil { - i.t.Fatalf("can't write: %s of len %d", err, len(b)) +func (i *Injector) Send(t *testing.T, b []byte) { + t.Helper() + + n, err := unix.Write(i.fd, b) + if err != nil { + t.Fatalf("can't write bytes of len %d: %s", len(b), err) + } + if n != len(b) { + t.Fatalf("got %d bytes written, want %d", n, len(b)) } } diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 407565078..a61054c2c 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -39,34 +39,34 @@ func TestFinWait2Timeout(t *testing.T) { t.Run(tt.description, func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() + defer conn.Close(t) + conn.Connect(t) - acceptFd, _ := dut.Accept(listenFd) + acceptFd, _ := dut.Accept(t, listenFd) if tt.linger2 { tv := unix.Timeval{Sec: 1, Usec: 0} - dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) + dut.SetSockOptTimeval(t, acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) } - dut.Close(acceptFd) + dut.Close(t, acceptFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) time.Sleep(5 * time.Second) - conn.Drain() + conn.Drain(t) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) if tt.linger2 { - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { t.Fatalf("expected a RST packet within a second but got none: %s", err) } } else { - if got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { + if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { t.Fatalf("expected no RST packets within ten seconds but got one: %s", got) } } diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go index 8dfd26ee8..2d59d552d 100644 --- a/test/packetimpact/tests/icmpv6_param_problem_test.go +++ b/test/packetimpact/tests/icmpv6_param_problem_test.go @@ -34,7 +34,7 @@ func TestICMPv6ParamProblemTest(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) - defer conn.Close() + defer conn.Close(t) ipv6 := testbench.IPv6{ // 254 is reserved and used for experimentation and testing. This should // cause an error. @@ -45,8 +45,8 @@ func TestICMPv6ParamProblemTest(t *testing.T) { Payload: []byte("hello world"), } - toSend := (*testbench.Connection)(&conn).CreateFrame(testbench.Layers{&ipv6}, &icmpv6) - (*testbench.Connection)(&conn).SendFrame(toSend) + toSend := (*testbench.Connection)(&conn).CreateFrame(t, testbench.Layers{&ipv6}, &icmpv6) + (*testbench.Connection)(&conn).SendFrame(t, toSend) // Build the expected ICMPv6 payload, which includes an index to the // problematic byte and also the problematic packet as described in @@ -72,7 +72,7 @@ func TestICMPv6ParamProblemTest(t *testing.T) { &expectedICMPv6, } timeout := time.Second - if _, err := conn.ExpectFrame(paramProblem, timeout); err != nil { + if _, err := conn.ExpectFrame(t, paramProblem, timeout); err != nil { t.Errorf("expected %s within %s but got none: %s", paramProblem, timeout, err) } } diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go index 70f6df5e0..cf881418c 100644 --- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go +++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go @@ -31,8 +31,8 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func recvTCPSegment(conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) { - layers, err := conn.ExpectData(expect, expectPayload, time.Second) +func recvTCPSegment(t *testing.T, conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) { + layers, err := conn.ExpectData(t, expect, expectPayload, time.Second) if err != nil { return 0, fmt.Errorf("failed to receive TCP segment: %s", err) } @@ -69,17 +69,17 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - remoteFD, _ := dut.Accept(listenFD) - defer dut.Close(remoteFD) + conn.Connect(t) + remoteFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, remoteFD) - dut.SetSockOptInt(remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) // TODO(b/129291778) The following socket option clears the DF bit on // IP packets sent over the socket, and is currently not supported by @@ -87,30 +87,30 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { // socket option being not supported does not affect the operation of // this test. Once the socket option is supported, the following call // can be changed to simply assert success. - ret, errno := dut.SetSockOptIntWithErrno(context.Background(), remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) + ret, errno := dut.SetSockOptIntWithErrno(context.Background(), t, remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) if ret == -1 && errno != unix.ENOTSUP { t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno) } samplePayload := &testbench.Payload{Bytes: tc.payload} - dut.Send(remoteFD, tc.payload, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, remoteFD, tc.payload, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err) } // Let the DUT estimate RTO with RTT from the DATA-ACK. // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which // we can skip sending this ACK. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - dut.Send(remoteFD, tc.payload, 0) - expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum()))} - originalID, err := recvTCPSegment(&conn, expectTCP, samplePayload) + dut.Send(t, remoteFD, tc.payload, 0) + expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))} + originalID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload) if err != nil { t.Fatalf("failed to receive TCP segment: %s", err) } - retransmitID, err := recvTCPSegment(&conn, expectTCP, samplePayload) + retransmitID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload) if err != nil { t.Fatalf("failed to receive retransmitted TCP segment: %s", err) } diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go index 7b462c8e2..b5f94ad4b 100644 --- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go +++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go @@ -48,7 +48,7 @@ func TestIPv6FragmentReassembly(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) - defer conn.Close() + defer conn.Close(t) firstPayloadToSend := make([]byte, firstPayloadLength) for i := range firstPayloadToSend { @@ -81,7 +81,7 @@ func TestIPv6FragmentReassembly(t *testing.T) { buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}), ) - conn.Send(testbench.IPv6{}, + conn.Send(t, testbench.IPv6{}, &testbench.IPv6FragmentExtHdr{ FragmentOffset: testbench.Uint16(0), MoreFragments: testbench.Bool(true), @@ -96,7 +96,7 @@ func TestIPv6FragmentReassembly(t *testing.T) { icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber) - conn.Send(testbench.IPv6{}, + conn.Send(t, testbench.IPv6{}, &testbench.IPv6FragmentExtHdr{ NextHeader: &icmpv6ProtoNum, FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8), @@ -107,7 +107,7 @@ func TestIPv6FragmentReassembly(t *testing.T) { Bytes: secondPayloadToSend, }) - gotEchoReplyFirstPart, err := conn.ExpectFrame(testbench.Layers{ + gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{ &testbench.Ether{}, &testbench.IPv6{}, &testbench.IPv6FragmentExtHdr{ @@ -142,7 +142,7 @@ func TestIPv6FragmentReassembly(t *testing.T) { hex.Dump(wantFirstPayload)) } - gotEchoReplySecondPart, err := conn.ExpectFrame(testbench.Layers{ + gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{ &testbench.Ether{}, &testbench.IPv6{}, &testbench.IPv6FragmentExtHdr{ diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go index 100b30ad7..d7d63cbd2 100644 --- a/test/packetimpact/tests/ipv6_unknown_options_action_test.go +++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go @@ -23,21 +23,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - tb "gvisor.dev/gvisor/test/packetimpact/testbench" + "gvisor.dev/gvisor/test/packetimpact/testbench" ) func init() { - tb.RegisterFlags(flag.CommandLine) + testbench.RegisterFlags(flag.CommandLine) } -func mkHopByHopOptionsExtHdr(optType byte) tb.Layer { - return &tb.IPv6HopByHopOptionsExtHdr{ +func mkHopByHopOptionsExtHdr(optType byte) testbench.Layer { + return &testbench.IPv6HopByHopOptionsExtHdr{ Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, } } -func mkDestinationOptionsExtHdr(optType byte) tb.Layer { - return &tb.IPv6DestinationOptionsExtHdr{ +func mkDestinationOptionsExtHdr(optType byte) testbench.Layer { + return &testbench.IPv6DestinationOptionsExtHdr{ Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, } } @@ -49,7 +49,7 @@ func optionTypeFromAction(action header.IPv6OptionUnknownAction) byte { func TestIPv6UnknownOptionAction(t *testing.T) { for _, tt := range []struct { description string - mkExtHdr func(optType byte) tb.Layer + mkExtHdr func(optType byte) testbench.Layer action header.IPv6OptionUnknownAction multicastDst bool wantICMPv6 bool @@ -140,21 +140,21 @@ func TestIPv6UnknownOptionAction(t *testing.T) { }, } { t.Run(tt.description, func(t *testing.T) { - dut := tb.NewDUT(t) + dut := testbench.NewDUT(t) defer dut.TearDown() - ipv6Conn := tb.NewIPv6Conn(t, tb.IPv6{}, tb.IPv6{}) - conn := (*tb.Connection)(&ipv6Conn) - defer ipv6Conn.Close() + ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + conn := (*testbench.Connection)(&ipv6Conn) + defer ipv6Conn.Close(t) - outgoingOverride := tb.Layers{} + outgoingOverride := testbench.Layers{} if tt.multicastDst { - outgoingOverride = tb.Layers{&tb.IPv6{ - DstAddr: tb.Address(tcpip.Address(net.ParseIP("ff02::1"))), + outgoingOverride = testbench.Layers{&testbench.IPv6{ + DstAddr: testbench.Address(tcpip.Address(net.ParseIP("ff02::1"))), }} } - outgoing := conn.CreateFrame(outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action))) - conn.SendFrame(outgoing) + outgoing := conn.CreateFrame(t, outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action))) + conn.SendFrame(t, outgoing) ipv6Sent := outgoing[1:] invokingPacket, err := ipv6Sent.ToBytes() if err != nil { @@ -167,12 +167,12 @@ func TestIPv6UnknownOptionAction(t *testing.T) { // after the IPv6 header (after NextHeader and ExtHdrLen). binary.BigEndian.PutUint32(icmpv6Payload, header.IPv6MinimumSize+2) icmpv6Payload = append(icmpv6Payload, invokingPacket...) - gotICMPv6, err := ipv6Conn.ExpectFrame(tb.Layers{ - &tb.Ether{}, - &tb.IPv6{}, - &tb.ICMPv6{ - Type: tb.ICMPv6Type(header.ICMPv6ParamProblem), - Code: tb.Byte(2), + gotICMPv6, err := ipv6Conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), + Code: testbench.Byte(2), Payload: icmpv6Payload, }, }, time.Second) diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go index 6e7ff41d7..e6a96f214 100644 --- a/test/packetimpact/tests/tcp_close_wait_ack_test.go +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -33,39 +33,39 @@ func init() { func TestCloseWaitAck(t *testing.T) { for _, tt := range []struct { description string - makeTestingTCP func(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP seqNumOffset seqnum.Size expectAck bool }{ - {"OTW", GenerateOTWSeqSegment, 0, false}, - {"OTW", GenerateOTWSeqSegment, 1, true}, - {"OTW", GenerateOTWSeqSegment, 2, true}, - {"ACK", GenerateUnaccACKSegment, 0, false}, - {"ACK", GenerateUnaccACKSegment, 1, true}, - {"ACK", GenerateUnaccACKSegment, 2, true}, + {"OTW", generateOTWSeqSegment, 0, false}, + {"OTW", generateOTWSeqSegment, 1, true}, + {"OTW", generateOTWSeqSegment, 2, true}, + {"ACK", generateUnaccACKSegment, 0, false}, + {"ACK", generateUnaccACKSegment, 1, true}, + {"ACK", generateUnaccACKSegment, 2, true}, } { t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) // Send a FIN to DUT to intiate the active close - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) - gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) } windowSize := seqnum.Size(*gotTCP.WindowSize) // Send a segment with OTW Seq / unacc ACK and expect an ACK back - conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) - gotAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if tt.expectAck && err != nil { t.Fatalf("expected an ack but got none: %s", err) } @@ -74,35 +74,36 @@ func TestCloseWaitAck(t *testing.T) { } // Now let's verify DUT is indeed in CLOSE_WAIT - dut.Close(acceptFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + dut.Close(t, acceptFd) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { t.Fatalf("expected DUT to send a FIN: %s", err) } // Ack the FIN from DUT - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) // Send some extra data to DUT - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { t.Fatalf("expected DUT to send an RST: %s", err) } }) } } -// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the -// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK -// is expected from the receiver. -func GenerateOTWSeqSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.LocalSeqNum().Add(windowSize) +// generateOTWSeqSegment generates an segment with +// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only +// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the +// receiver. +func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} } -// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated -// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is -// expected from the receiver. -func GenerateUnaccACKSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.RemoteSeqNum() +// generateUnaccACKSegment generates an segment with +// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable +// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. +func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.RemoteSeqNum(t) unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} } diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go index fb8f48629..8feea4a82 100644 --- a/test/packetimpact/tests/tcp_cork_mss_test.go +++ b/test/packetimpact/tests/tcp_cork_mss_test.go @@ -32,53 +32,53 @@ func init() { func TestTCPCorkMSS(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) const mss = uint32(header.TCPDefaultMSS) options := make([]byte, header.TCPOptionMSSLength) header.EncodeMSSOption(mss, options) - conn.ConnectWithOptions(options) + conn.ConnectWithOptions(t, options) - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) - dut.SetSockOptInt(acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1) + dut.SetSockOptInt(t, acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1) // Let the dut application send 2 small segments to be held up and coalesced // until the application sends a larger segment to fill up to > MSS. sampleData := []byte("Sample Data") - dut.Send(acceptFD, sampleData, 0) - dut.Send(acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) expectedData := sampleData expectedData = append(expectedData, sampleData...) largeData := make([]byte, mss+1) expectedData = append(expectedData, largeData...) - dut.Send(acceptFD, largeData, 0) + dut.Send(t, acceptFD, largeData, 0) // Expect the segments to be coalesced and sent and capped to MSS. expectedPayload := testbench.Payload{Bytes: expectedData[:mss]} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the coalesced segment to be split and transmitted. expectedPayload = testbench.Payload{Bytes: expectedData[mss:]} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Check for segments to *not* be held up because of TCP_CORK when // the current send window is less than MSS. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) - dut.Send(acceptFD, sampleData, 0) - dut.Send(acceptFD, sampleData, 0) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) + dut.Send(t, acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) } diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go index 652b530d0..22937d92f 100644 --- a/test/packetimpact/tests/tcp_handshake_window_size_test.go +++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go @@ -33,14 +33,14 @@ func init() { func TestTCPHandshakeWindowSize(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Start handshake with zero window size. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK: %s", err) } // Update the advertised window size to a non-zero value with the ACK that @@ -48,10 +48,10 @@ func TestTCPHandshakeWindowSize(t *testing.T) { // // Set the window size with MSB set and expect the dut to treat it as // an unsigned value. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) - acceptFd, _ := dut.Accept(listenFD) - defer dut.Close(acceptFd) + acceptFd, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFd) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} @@ -59,8 +59,8 @@ func TestTCPHandshakeWindowSize(t *testing.T) { // Since we advertised a zero window followed by a non-zero window, // expect the dut to honor the recently advertised non-zero window // and actually send out the data instead of probing for zero window. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } } diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go index 868a08da8..900352fa1 100644 --- a/test/packetimpact/tests/tcp_network_unreachable_test.go +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -38,29 +38,29 @@ func TestTCPSynSentUnreachable(t *testing.T) { // Create the DUT and connection. dut := testbench.NewDUT(t) defer dut.TearDown() - clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) port := uint16(9001) conn := testbench.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port}) - defer conn.Close() + defer conn.Close(t) // Bring the DUT to SYN-SENT state with a non-blocking connect. ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) defer cancel() sa := unix.SockaddrInet4{Port: int(port)} copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv4)).To4()) - if _, err := dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) } // Get the SYN. - tcpLayers, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) if err != nil { t.Fatalf("expected SYN: %s", err) } // Send a host unreachable message. rawConn := (*testbench.Connection)(&conn) - layers := rawConn.CreateFrame(nil) + layers := rawConn.CreateFrame(t, nil) layers = layers[:len(layers)-1] const ipLayer = 1 const tcpLayer = ipLayer + 1 @@ -74,9 +74,9 @@ func TestTCPSynSentUnreachable(t *testing.T) { } var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4HostUnreachable)} layers = append(layers, &icmpv4, ip, tcp) - rawConn.SendFrameStateless(layers) + rawConn.SendFrameStateless(t, layers) - if _, err = dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) { + if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) { t.Errorf("expected connect to fail with EHOSTUNREACH, but got %v", err) } } @@ -88,9 +88,9 @@ func TestTCPSynSentUnreachable6(t *testing.T) { // Create the DUT and connection. dut := testbench.NewDUT(t) defer dut.TearDown() - clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6)) + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6)) conn := testbench.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort}) - defer conn.Close() + defer conn.Close(t) // Bring the DUT to SYN-SENT state with a non-blocking connect. ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) @@ -100,19 +100,19 @@ func TestTCPSynSentUnreachable6(t *testing.T) { ZoneId: uint32(testbench.RemoteInterfaceID), } copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv6)).To16()) - if _, err := dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) } // Get the SYN. - tcpLayers, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) if err != nil { t.Fatalf("expected SYN: %s", err) } // Send a host unreachable message. rawConn := (*testbench.Connection)(&conn) - layers := rawConn.CreateFrame(nil) + layers := rawConn.CreateFrame(t, nil) layers = layers[:len(layers)-1] const ipLayer = 1 const tcpLayer = ipLayer + 1 @@ -131,9 +131,9 @@ func TestTCPSynSentUnreachable6(t *testing.T) { Payload: []byte{0, 0, 0, 0}, } layers = append(layers, &icmpv6, ip, tcp) - rawConn.SendFrameStateless(layers) + rawConn.SendFrameStateless(t, layers) - if _, err = dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) { + if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) { t.Errorf("expected connect to fail with ENETUNREACH, but got %v", err) } } diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go index b9b3e91d3..82b7a85ff 100644 --- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -31,12 +31,12 @@ func init() { func TestTcpNoAcceptCloseReset(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - conn.Connect() - defer conn.Close() - dut.Close(listenFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + conn.Connect(t) + defer conn.Close(t) + dut.Close(t, listenFd) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { t.Fatalf("expected a RST-ACK packet but got none: %s", err) } } diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go index ad8c74234..08f759f7c 100644 --- a/test/packetimpact/tests/tcp_outside_the_window_test.go +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -63,25 +63,25 @@ func TestTCPOutsideTheWindow(t *testing.T) { t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) - windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset - conn.Drain() + windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + tt.seqNumOffset + conn.Drain(t) // Ignore whatever incrementing that this out-of-order packet might cause // to the AckNum. - localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum())) - conn.Send(testbench.TCP{ + localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + conn.Send(t, testbench.TCP{ Flags: testbench.Uint8(tt.tcpFlags), - SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))), + SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), }, tt.payload...) timeout := 3 * time.Second - gotACK, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) if tt.expectACK && err != nil { t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) } diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go index 55db4ece6..37f3b56dd 100644 --- a/test/packetimpact/tests/tcp_paws_mechanism_test.go +++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go @@ -32,15 +32,15 @@ func init() { func TestPAWSMechanism(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) options := make([]byte, header.TCPOptionTSLength) header.EncodeTSOption(currentTS(), 0, options) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) - synAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) + synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("didn't get synack during handshake: %s", err) } @@ -50,9 +50,9 @@ func TestPAWSMechanism(t *testing.T) { } tsecr := parsedSynOpts.TSVal header.EncodeTSOption(currentTS(), tsecr, options) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) sampleData := []byte("Sample Data") sentTSVal := currentTS() @@ -61,9 +61,9 @@ func TestPAWSMechanism(t *testing.T) { // every time we send one, it should not cause any flakiness because timestamps // only need to be non-decreasing. time.Sleep(3 * time.Millisecond) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) - gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected an ACK but got none: %s", err) } @@ -86,9 +86,9 @@ func TestPAWSMechanism(t *testing.T) { // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness // due to the exact same reasoning discussed above. time.Sleep(3 * time.Millisecond) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) - gotTCP, err = conn.Expect(testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err) } diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go index 8fbec893b..d9f3ea0f2 100644 --- a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go +++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go @@ -52,26 +52,26 @@ func TestQueueReceiveInSynSent(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - socket, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) sampleData := []byte("Sample Data") - dut.SetNonBlocking(socket, true) - if _, err := dut.ConnectWithErrno(context.Background(), socket, conn.LocalAddr()); !errors.Is(err, syscall.EINPROGRESS) { + dut.SetNonBlocking(t, socket, true) + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) } - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { t.Fatalf("expected a SYN from DUT, but got none: %s", err) } - if _, _, err := dut.RecvWithErrno(context.Background(), socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { + if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) } // Test blocking read. - dut.SetNonBlocking(socket, false) + dut.SetNonBlocking(t, socket, false) var wg sync.WaitGroup defer wg.Wait() @@ -86,7 +86,7 @@ func TestQueueReceiveInSynSent(t *testing.T) { block.Done() // Issue RECEIVE call in SYN-SENT, this should be queued for // process until the connection is established. - n, buff, err := dut.RecvWithErrno(ctx, socket, int32(len(sampleData)), 0) + n, buff, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0) if tt.reset { if err != syscall.Errno(unix.ECONNREFUSED) { t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) @@ -112,19 +112,19 @@ func TestQueueReceiveInSynSent(t *testing.T) { time.Sleep(100 * time.Millisecond) if tt.reset { - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) return } // Bring the connection to Established. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK from DUT, but got none: %s", err) } // Send sample payload and expect an ACK. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK from DUT, but got none: %s", err) } }) diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go index a5378a9dd..8742819ca 100644 --- a/test/packetimpact/tests/tcp_reordering_test.go +++ b/test/packetimpact/tests/tcp_reordering_test.go @@ -32,10 +32,10 @@ func init() { func TestReorderingWindow(t *testing.T) { dut := tb.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Enable SACK. opts := make([]byte, 40) @@ -49,17 +49,17 @@ func TestReorderingWindow(t *testing.T) { const mss = minMTU - header.IPv4MinimumSize - header.TCPMinimumSize optsOff += header.EncodeMSSOption(mss, opts[optsOff:]) - conn.ConnectWithOptions(opts[:optsOff]) + conn.ConnectWithOptions(t, opts[:optsOff]) - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) if tb.DUTType == "linux" { // Linux has changed its handling of reordering, force the old behavior. - dut.SetSockOpt(acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) + dut.SetSockOpt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) } - pls := dut.GetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) + pls := dut.GetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) if tb.DUTType == "netstack" { // netstack does not impliment TCP_MAXSEG correctly. Fake it // here. Netstack uses the max SACK size which is 32. The MSS @@ -69,13 +69,13 @@ func TestReorderingWindow(t *testing.T) { payload := make([]byte, pls) - seqNum1 := *conn.RemoteSeqNum() + seqNum1 := *conn.RemoteSeqNum(t) const numPkts = 10 // Send some packets, checking that we receive each. for i, sn := 0, seqNum1; i < numPkts; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -86,7 +86,7 @@ func TestReorderingWindow(t *testing.T) { } } - seqNum2 := *conn.RemoteSeqNum() + seqNum2 := *conn.RemoteSeqNum(t) // SACK packets #2-4. sackBlock := make([]byte, 40) @@ -97,13 +97,13 @@ func TestReorderingWindow(t *testing.T) { seqNum1.Add(seqnum.Size(len(payload))), seqNum1.Add(seqnum.Size(4 * len(payload))), }}, sackBlock[sbOff:]) - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) // ACK first packet. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))}) // Check for retransmit. - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second) if err != nil { t.Error("Expect for retransmit:", err) } @@ -123,14 +123,14 @@ func TestReorderingWindow(t *testing.T) { seqNum1.Add(seqnum.Size(4 * len(payload))), }}, dsackBlock[dsbOff:]) - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]}) // Send half of the original window of packets, checking that we // received each. for i, sn := 0, seqNum2; i < numPkts/2; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -144,8 +144,8 @@ func TestReorderingWindow(t *testing.T) { if tb.DUTType == "netstack" { // The window should now be halved, so we should receive any // more, even if we send them. - dut.Send(acceptFd, payload, 0) - if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + dut.Send(t, acceptFd, payload, 0) + if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) } return @@ -153,9 +153,9 @@ func TestReorderingWindow(t *testing.T) { // Linux reduces the window by three. Check that we can receive the rest. for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -167,8 +167,8 @@ func TestReorderingWindow(t *testing.T) { } // The window should now be full. - dut.Send(acceptFd, payload, 0) - if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + dut.Send(t, acceptFd, payload, 0) + if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) } } diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go index 6940eb7fb..072014ff8 100644 --- a/test/packetimpact/tests/tcp_retransmits_test.go +++ b/test/packetimpact/tests/tcp_retransmits_test.go @@ -33,41 +33,41 @@ func init() { func TestRetransmits(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK. // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which // we can skip sending this ACK. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) startRTO := time.Second current := startRTO first := time.Now() - dut.Send(acceptFd, sampleData, 0) - seq := testbench.Uint32(uint32(*conn.RemoteSeqNum())) - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Expect retransmits of the same segment. for i := 0; i < 5; i++ { start := time.Now() - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil { t.Fatalf("expected payload was not received: %s loop %d", err, i) } if i == 0 { diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go index 90ab85419..f91b06ba1 100644 --- a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go +++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go @@ -61,23 +61,23 @@ func TestSendWindowSizesPiggyback(t *testing.T) { t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1} - if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } @@ -86,18 +86,18 @@ func TestSendWindowSizesPiggyback(t *testing.T) { if tt.enqueue { // Enqueue a segment for the dut to transmit. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) } // Send ACK for the previous segment along with data for the dut to // receive and ACK back. Sending this ACK would make room for the dut // to transmit any enqueued segment. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) // Expect the dut to piggyback the ACK for received data along with // the segment enqueued for transmit. expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2} - if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } }) diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go index 7d5deab01..57d034dd1 100644 --- a/test/packetimpact/tests/tcp_synrcvd_reset_test.go +++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go @@ -32,21 +32,21 @@ func init() { func TestTCPSynRcvdReset(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Expect dut connection to have transitioned to SYN-RCVD state. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST %s", err) } } diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go index 6898a2239..eac8eb19d 100644 --- a/test/packetimpact/tests/tcp_synsent_reset_test.go +++ b/test/packetimpact/tests/tcp_synsent_reset_test.go @@ -31,17 +31,19 @@ func init() { // dutSynSentState sets up the dut connection in SYN-SENT state. func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) { + t.Helper() + dut := tb.NewDUT(t) - clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4)) + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4)) port := uint16(9001) conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port}) sa := unix.SockaddrInet4{Port: int(port)} copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4()) // Bring the dut to SYN-SENT state with a non-blocking connect. - dut.Connect(clientFD, &sa) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { + dut.Connect(t, clientFD, &sa) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { t.Fatalf("expected SYN\n") } @@ -51,13 +53,13 @@ func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) { // TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition. func TestTCPSynSentReset(t *testing.T) { dut, conn, _, _ := dutSynSentState(t) - defer conn.Close() + defer conn.Close(t) defer dut.TearDown() - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) // Expect the connection to have closed. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } @@ -67,22 +69,22 @@ func TestTCPSynSentReset(t *testing.T) { func TestTCPSynSentRcvdReset(t *testing.T) { dut, c, remotePort, clientPort := dutSynSentState(t) defer dut.TearDown() - defer c.Close() + defer c.Close(t) conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Initiate new SYN connection with the same port pair // (simultaneous open case), expect the dut connection to move to // SYN-RCVD state - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK %s\n", err) } - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go index 87e45d765..551dc78e7 100644 --- a/test/packetimpact/tests/tcp_user_timeout_test.go +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -16,7 +16,6 @@ package tcp_user_timeout_test import ( "flag" - "fmt" "testing" "time" @@ -29,22 +28,20 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func sendPayload(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error { +func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) { sampleData := make([]byte, 100) for i := range sampleData { sampleData[i] = uint8(i) } - conn.Drain() - dut.Send(fd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { - return fmt.Errorf("expected data but got none: %w", err) + conn.Drain(t) + dut.Send(t, fd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { + t.Fatalf("expected data but got none: %w", err) } - return nil } -func sendFIN(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error { - dut.Close(fd) - return nil +func sendFIN(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) { + dut.Close(t, fd) } func TestTCPUserTimeout(t *testing.T) { @@ -59,7 +56,7 @@ func TestTCPUserTimeout(t *testing.T) { } { for _, ttf := range []struct { description string - f func(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error + f func(_ *testing.T, _ *testbench.TCPIPv4, _ *testbench.DUT, fd int32) }{ {"AfterPayload", sendPayload}, {"AfterFIN", sendFIN}, @@ -68,31 +65,29 @@ func TestTCPUserTimeout(t *testing.T) { // Create a socket, listen, TCP handshake, and accept. dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() - acceptFD, _ := dut.Accept(listenFD) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) if tt.userTimeout != 0 { - dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) + dut.SetSockOptInt(t, acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) } - if err := ttf.f(&conn, &dut, acceptFD); err != nil { - t.Fatal(err) - } + ttf.f(t, &conn, &dut, acceptFD) time.Sleep(tt.sendDelay) - conn.Drain() - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Drain(t) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) // If TCP_USER_TIMEOUT was set and the above delay was longer than the // TCP_USER_TIMEOUT then the DUT should send a RST in response to the // testbench's packet. expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout expectTimeout := 5 * time.Second - got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) + got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) if expectRST && err != nil { t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err) } diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go index e78d04756..5b001fbec 100644 --- a/test/packetimpact/tests/tcp_window_shrink_test.go +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -31,43 +31,43 @@ func init() { func TestWindowShrink(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - dut.Send(acceptFd, sampleData, 0) - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // We close our receiving window here - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) - dut.Send(acceptFd, []byte("Sample Data"), 0) + dut.Send(t, acceptFd, []byte("Sample Data"), 0) // Note: There is another kind of zero-window probing which Windows uses (by sending one // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change // the following lines. - expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1 - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + expectedRemoteSeqNum := *conn.RemoteSeqNum(t) - 1 + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", expectedRemoteSeqNum, err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go index 8c89d57c9..da93267d6 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -33,27 +33,27 @@ func init() { func TestZeroWindowProbeRetransmit(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -63,15 +63,15 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { // of the recorded first zero probe transmission duration. // // Advertize zero receive window again. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) - ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) startProbeDuration := time.Second current := startProbeDuration first := time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect the dut to keep the connection alive as long as the remote is // acknowledging the zero-window probes. for i := 0; i < 5; i++ { @@ -79,7 +79,7 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { // Expect zero-window probe with a timeout which is a function of the typical // first retransmission time. The retransmission times is supposed to // exponentially increase. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i) } if i == 0 { @@ -92,14 +92,13 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { t.Errorf("got zero probe %d after %s, want >= %s", i, got, want) } // Acknowledge the zero-window probes from the dut. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) current *= 2 } // Advertize non-zero window. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the dut to recover and transmit data. - if _, err := conn.ExpectData(&testbench. - TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go index 649fd5699..44cac42f8 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go @@ -33,29 +33,29 @@ func init() { func TestZeroWindowProbe(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} start := time.Now() // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } sendTime := time.Now().Sub(start) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -63,24 +63,24 @@ func TestZeroWindowProbe(t *testing.T) { // probe to be sent. // // Advertize zero window to the dut. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Expected sequence number of the zero window probe. - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) // Expected ack number of the ACK for the probe. - ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) // Expect there are no zero-window probes sent until there is data to be sent out // from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil { t.Fatalf("unexpected packet with sequence number %d: %s", probeSeq, err) } start = time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect zero-window probe from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) } // Expect the probe to be sent after some time. Compare against the previous @@ -94,9 +94,9 @@ func TestZeroWindowProbe(t *testing.T) { // and sends out the sample payload after the send window opens. // // Advertize non-zero window to the dut and ack the zero window probe. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the dut to recover and transmit data. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } @@ -104,9 +104,9 @@ func TestZeroWindowProbe(t *testing.T) { // Check if the dut responds as we do for a similar probe sent to it. // Basically with sequence number to one byte behind the unacknowledged // sequence number. - p := testbench.Uint32(uint32(*conn.LocalSeqNum())) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum() - 1))}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { + p := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { t.Fatalf("expected a packet with ack number: %d: %s", p, err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go index 3c467b14f..09a1c653f 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go @@ -33,27 +33,27 @@ func init() { func TestZeroWindowProbeUserTimeout(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -61,15 +61,15 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // probe to be sent. // // Advertize zero window to the dut. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Expected sequence number of the zero window probe. - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) start := time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect zero-window probe from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) } // Record the duration for first probe, the dut sends the zero window probe after @@ -80,19 +80,19 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // when the dut is sending zero-window probes. // // Reduce the retransmit timeout. - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) // Advertize zero window again. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Ask the dut to send out data that would trigger zero window probe retransmissions. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Wait for the connection to timeout after multiple zero-window probe retransmissions. time.Sleep(8 * startProbeDuration) // Expect the connection to have timed out and closed which would cause the dut // to reply with a RST to the ACK we send. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go index b0315e67c..d30177e64 100644 --- a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go +++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go @@ -36,11 +36,11 @@ func init() { func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4)) - defer dut.Close(remoteFD) - dut.SetSockOptTimeval(remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4)) + defer dut.Close(t, remoteFD) + dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) for _, mcastAddr := range []net.IP{ net.IPv4allsys, @@ -50,11 +50,12 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { } { t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { conn.SendIP( + t, testbench.IPv4{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To4()))}, testbench.UDP{}, ) - ret, payload, errno := dut.RecvWithErrno(context.Background(), remoteFD, 100, 0) + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) } @@ -65,11 +66,11 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6)) - defer dut.Close(remoteFD) - dut.SetSockOptTimeval(remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6)) + defer dut.Close(t, remoteFD) + dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) for _, mcastAddr := range []net.IP{ net.IPv6interfacelocalallnodes, @@ -80,10 +81,11 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) { } { t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { conn.SendIPv6( + t, testbench.IPv6{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To16()))}, testbench.UDP{}, ) - ret, payload, errno := dut.RecvWithErrno(context.Background(), remoteFD, 100, 0) + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) } diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go index b754918f6..715e8f5b5 100644 --- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -72,7 +72,7 @@ func (e icmpError) ToICMPv4() *testbench.ICMPv4 { type errorDetection struct { name string useValidConn bool - f func(context.Context, testData) error + f func(context.Context, *testing.T, testData) } type testData struct { @@ -95,12 +95,14 @@ func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno { } // sendICMPError sends an ICMP error message in response to a UDP datagram. -func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) error { - layers := (*testbench.Connection)(conn).CreateFrame(nil) +func sendICMPError(t *testing.T, conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) { + t.Helper() + + layers := (*testbench.Connection)(conn).CreateFrame(t, nil) layers = layers[:len(layers)-1] ip, ok := udp.Prev().(*testbench.IPv4) if !ok { - return fmt.Errorf("expected %s to be IPv4", udp.Prev()) + t.Fatalf("expected %s to be IPv4", udp.Prev()) } if icmpErr == timeToLiveExceeded { *ip.TTL = 1 @@ -114,84 +116,82 @@ func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UD // resulting in a mal-formed packet. layers = append(layers, icmpErr.ToICMPv4(), ip, udp) - (*testbench.Connection)(conn).SendFrameStateless(layers) - return nil + (*testbench.Connection)(conn).SendFrameStateless(t, layers) } // testRecv tests observing the ICMP error through the recv syscall. A packet // is sent to the DUT, and if wantErrno is non-zero, then the first recv should // fail and the second should succeed. Otherwise if wantErrno is zero then the // first recv should succeed immediately. -func testRecv(ctx context.Context, d testData) error { +func testRecv(ctx context.Context, t *testing.T, d testData) { + t.Helper() + // Check that receiving on the clean socket works. - d.conn.Send(testbench.UDP{DstPort: &d.cleanPort}) - d.dut.Recv(d.cleanFD, 100, 0) + d.conn.Send(t, testbench.UDP{DstPort: &d.cleanPort}) + d.dut.Recv(t, d.cleanFD, 100, 0) - d.conn.Send(testbench.UDP{}) + d.conn.Send(t, testbench.UDP{}) if d.wantErrno != syscall.Errno(0) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - ret, _, err := d.dut.RecvWithErrno(ctx, d.remoteFD, 100, 0) + ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0) if ret != -1 { - return fmt.Errorf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + t.Fatalf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) } if err != d.wantErrno { - return fmt.Errorf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + t.Fatalf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) } } - d.dut.Recv(d.remoteFD, 100, 0) - return nil + d.dut.Recv(t, d.remoteFD, 100, 0) } // testSendTo tests observing the ICMP error through the send syscall. If // wantErrno is non-zero, the first send should fail and a subsequent send // should suceed; while if wantErrno is zero then the first send should just // succeed. -func testSendTo(ctx context.Context, d testData) error { +func testSendTo(ctx context.Context, t *testing.T, d testData) { // Check that sending on the clean socket works. - d.dut.SendTo(d.cleanFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet from clean socket on DUT: %s", err) + d.dut.SendTo(t, d.cleanFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet from clean socket on DUT: %s", err) } if d.wantErrno != syscall.Errno(0) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - ret, err := d.dut.SendToWithErrno(ctx, d.remoteFD, nil, 0, d.conn.LocalAddr()) + ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) if ret != -1 { - return fmt.Errorf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + t.Fatalf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) } if err != d.wantErrno { - return fmt.Errorf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + t.Fatalf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) } } - d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet as expected: %s", err) + d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet as expected: %s", err) } - return nil } -func testSockOpt(_ context.Context, d testData) error { +func testSockOpt(_ context.Context, t *testing.T, d testData) { // Check that there's no pending error on the clean socket. - if errno := syscall.Errno(d.dut.GetSockOptInt(d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { - return fmt.Errorf("unexpected error (%[1]d) %[1]v on clean socket", errno) + if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { + t.Fatalf("unexpected error (%[1]d) %[1]v on clean socket", errno) } - if errno := syscall.Errno(d.dut.GetSockOptInt(d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { - return fmt.Errorf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) + if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { + t.Fatalf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) } // Check that after clearing socket error, sending doesn't fail. - d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet as expected: %s", err) + d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet as expected: %s", err) } - return nil } // TestUDPICMPErrorPropagation tests that ICMP error messages in response to @@ -227,31 +227,29 @@ func TestUDPICMPErrorPropagation(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(remoteFD) + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(t, remoteFD) // Create a second, clean socket on the DUT to ensure that the ICMP // error messages only affect the sockets they are intended for. - cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(cleanFD) + cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(t, cleanFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) if connect { - dut.Connect(remoteFD, conn.LocalAddr()) - dut.Connect(cleanFD, conn.LocalAddr()) + dut.Connect(t, remoteFD, conn.LocalAddr(t)) + dut.Connect(t, cleanFD, conn.LocalAddr(t)) } - dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) - udp, err := conn.Expect(testbench.UDP{}, time.Second) + dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t)) + udp, err := conn.Expect(t, testbench.UDP{}, time.Second) if err != nil { t.Fatalf("did not receive message from DUT: %s", err) } - if err := sendICMPError(&conn, icmpErr, udp); err != nil { - t.Fatal(err) - } + sendICMPError(t, &conn, icmpErr, udp) errDetectConn := &conn if errDetect.useValidConn { @@ -260,14 +258,12 @@ func TestUDPICMPErrorPropagation(t *testing.T) { // interactions between it and the the DUT should be independent of // the ICMP error at least at the port level. connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer connClean.Close() + defer connClean.Close(t) errDetectConn = &connClean } - if err := errDetect.f(context.Background(), testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}); err != nil { - t.Fatal(err) - } + errDetect.f(context.Background(), t, testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}) }) } } @@ -285,24 +281,24 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(remoteFD) + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(t, remoteFD) // Create a second, clean socket on the DUT to ensure that the ICMP // error messages only affect the sockets they are intended for. - cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(cleanFD) + cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(t, cleanFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) if connect { - dut.Connect(remoteFD, conn.LocalAddr()) - dut.Connect(cleanFD, conn.LocalAddr()) + dut.Connect(t, remoteFD, conn.LocalAddr(t)) + dut.Connect(t, cleanFD, conn.LocalAddr(t)) } - dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) - udp, err := conn.Expect(testbench.UDP{}, time.Second) + dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t)) + udp, err := conn.Expect(t, testbench.UDP{}, time.Second) if err != nil { t.Fatalf("did not receive message from DUT: %s", err) } @@ -316,7 +312,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0) + ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0) if ret != -1 { t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) return @@ -330,7 +326,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 { + if ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0); ret == -1 { t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err) } }() @@ -341,7 +337,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 { + if ret, _, err := dut.RecvWithErrno(ctx, t, cleanFD, 100, 0); ret == -1 { t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err) } }() @@ -352,12 +348,10 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { // alternative is available. time.Sleep(2 * time.Second) - if err := sendICMPError(&conn, icmpErr, udp); err != nil { - t.Fatal(err) - } + sendICMPError(t, &conn, icmpErr, udp) - conn.Send(testbench.UDP{DstPort: &cleanPort}) - conn.Send(testbench.UDP{}) + conn.Send(t, testbench.UDP{DstPort: &cleanPort}) + conn.Send(t, testbench.UDP{}) wg.Wait() }) } diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go index 263a54291..fcd202643 100644 --- a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go +++ b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go @@ -31,10 +31,10 @@ func init() { func TestUDPRecvMulticastBroadcast(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4(0, 0, 0, 0)) - defer dut.Close(boundFD) + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4(0, 0, 0, 0)) + defer dut.Close(t, boundFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) for _, bcastAddr := range []net.IP{ broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)), @@ -43,12 +43,13 @@ func TestUDPRecvMulticastBroadcast(t *testing.T) { } { payload := testbench.GenerateRandomPayload(t, 1<<10) conn.SendIP( + t, testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(bcastAddr.To4()))}, testbench.UDP{}, &testbench.Payload{Bytes: payload}, ) t.Logf("Receiving packet sent to address: %s", bcastAddr) - if got, want := string(dut.Recv(boundFD, int32(len(payload)), 0)), string(payload); got != want { + if got, want := string(dut.Recv(t, boundFD, int32(len(payload)), 0)), string(payload); got != want { t.Errorf("received payload does not match sent payload got: %s, want: %s", got, want) } } diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go index bd53ad90b..dc20275d6 100644 --- a/test/packetimpact/tests/udp_send_recv_dgram_test.go +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -29,10 +29,10 @@ func init() { } type udpConn interface { - Send(testbench.UDP, ...testbench.Layer) - ExpectData(testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error) - Drain() - Close() + Send(*testing.T, testbench.UDP, ...testbench.Layer) + ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error) + Drain(*testing.T) + Close(*testing.T) } func TestUDP(t *testing.T) { @@ -51,21 +51,21 @@ func TestUDP(t *testing.T) { } else { addr = testbench.RemoteIPv6 } - boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr)) - defer dut.Close(boundFD) + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr)) + defer dut.Close(t, boundFD) var conn udpConn var localAddr unix.Sockaddr if isIPv4 { v4Conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - localAddr = v4Conn.LocalAddr() + localAddr = v4Conn.LocalAddr(t) conn = &v4Conn } else { v6Conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - localAddr = v6Conn.LocalAddr() + localAddr = v6Conn.LocalAddr(t) conn = &v6Conn } - defer conn.Close() + defer conn.Close(t) testCases := []struct { name string @@ -81,17 +81,17 @@ func TestUDP(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Run("Send", func(t *testing.T) { - conn.Send(testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) - if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), string(tc.payload); got != want { + conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) + if got, want := string(dut.Recv(t, boundFD, int32(len(tc.payload)), 0)), string(tc.payload); got != want { t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want) } }) t.Run("Recv", func(t *testing.T) { - conn.Drain() - if got, want := int(dut.SendTo(boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want { + conn.Drain(t) + if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want { t.Fatalf("short write got: %d, want: %d", got, want) } - if _, err := conn.ExpectData(testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil { + if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil { t.Fatal(err) } }) diff --git a/test/runner/BUILD b/test/runner/BUILD index 1f45a6922..63c7ec83a 100644 --- a/test/runner/BUILD +++ b/test/runner/BUILD @@ -17,6 +17,7 @@ go_binary( "//test/runner/gtest", "//test/uds", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 600cb5192..c92392b35 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -157,7 +157,7 @@ def syscall_test( platform = "native", use_tmpfs = False, add_uds_tree = add_uds_tree, - tags = tags, + tags = list(tags), ) for (platform, platform_tags) in platforms.items(): diff --git a/test/runner/runner.go b/test/runner/runner.go index 2296f3a46..bc4b39cbb 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -30,6 +30,7 @@ import ( "time" specs "github.com/opencontainers/runtime-spec/specs-go" + "github.com/syndtr/gocapability/capability" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/test/testutil" @@ -105,6 +106,13 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + + if specutils.HasCapabilities(capability.CAP_NET_ADMIN) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Cloneflags: syscall.CLONE_NEWNET, + } + } + if err := cmd.Run(); err != nil { ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) diff --git a/tools/bazel.mk b/tools/bazel.mk index e27e907ab..54844ebbc 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -15,6 +15,7 @@ # limitations under the License. # See base Makefile. +SHELL=/bin/bash -o pipefail BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ xargs -n 1 basename 2>/dev/null) @@ -22,8 +23,11 @@ BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ # Bazel container configuration (see below). USER ?= gvisor HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) +BUILDER_BASE := gvisor.dev/images/default +BUILDER_IMAGE := gvisor.dev/images/builder +BUILDER_NAME ?= gvisor-builder-$(HASH) DOCKER_NAME ?= gvisor-bazel-$(HASH) -DOCKER_PRIVILEGED ?= --privileged --network host +DOCKER_PRIVILEGED ?= --privileged BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) DOCKER_SOCKET := /var/run/docker.sock @@ -32,17 +36,25 @@ DOCKER_SOCKET := /var/run/docker.sock OPTIONS += --test_output=errors --keep_going --verbose_failures=true BAZEL := bazel $(STARTUP_OPTIONS) -# Non-configurable. +# Basic options. UID := $(shell id -u ${USER}) GID := $(shell id -g ${USER}) USERADD_OPTIONS := FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS) +FULL_DOCKER_RUN_OPTIONS += --user $(UID):$(GID) +FULL_DOCKER_RUN_OPTIONS += --entrypoint "" +FULL_DOCKER_RUN_OPTIONS += --init FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)" FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)" FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp" +FULL_DOCKER_EXEC_OPTIONS := --user $(UID):$(GID) +FULL_DOCKER_EXEC_OPTIONS += -i + +# Add docker passthrough options. ifneq ($(DOCKER_PRIVILEGED),) FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED) +FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED) DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET)) ifneq ($(GID),$(DOCKER_GROUP)) USERADD_OPTIONS += --groups $(DOCKER_GROUP) @@ -50,7 +62,30 @@ GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) && FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP) endif endif -SHELL=/bin/bash -o pipefail + +# Add KVM passthrough options. +ifneq (,$(wildcard /dev/kvm)) +FULL_DOCKER_RUN_OPTIONS += --device=/dev/kvm +KVM_GROUP := $(shell stat -c '%g' /dev/kvm) +ifneq ($(GID),$(KVM_GROUP)) +USERADD_OPTIONS += --groups $(KVM_GROUP) +GROUPADD_DOCKER += groupadd --gid $(KVM_GROUP) --non-unique kvm-$(HASH) && +FULL_DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP) +endif +endif + +bazel-image: load-default + @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi + docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \ + $(BUILDER_BASE) \ + sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ + $(GROUPADD_DOCKER) \ + useradd --uid $(UID) --non-unique --no-create-home \ + --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ + if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi" + docker commit $(BUILDER_NAME) $(BUILDER_IMAGE) + @docker rm -f $(BUILDER_NAME) +.PHONY: bazel-image ## ## Bazel helpers. @@ -65,41 +100,37 @@ SHELL=/bin/bash -o pipefail ## GCLOUD_CONFIG - The gcloud config directory (detect: detected). ## DOCKER_SOCKET - The Docker socket (default: detected). ## -bazel-server-start: load-default ## Starts the bazel server. +bazel-server-start: bazel-image ## Starts the bazel server. @mkdir -p $(BAZEL_CACHE) @mkdir -p $(GCLOUD_CONFIG) @if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi - docker run -d --rm \ - --init \ - --name $(DOCKER_NAME) \ - --user 0:0 $(DOCKER_GROUP_OPTIONS) \ + # This command runs a bazel server, and the container sticks around + # until the bazel server exits. This should ensure that it does not + # exit in the middle of running a build, but also it won't stick around + # forever. The build commands wrap around an appropriate exec into the + # container in order to perform work via the bazel client. + docker run -d --rm --name $(DOCKER_NAME) \ -v "$(CURDIR):$(CURDIR)" \ --workdir "$(CURDIR)" \ - --entrypoint "" \ $(FULL_DOCKER_RUN_OPTIONS) \ - gvisor.dev/images/default \ - sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ - $(GROUPADD_DOCKER) \ - useradd --uid $(UID) --non-unique --no-create-home --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ - $(BAZEL) version && \ - exec tail --pid=\$$($(BAZEL) info server_pid) -f /dev/null" - @while :; do if docker logs $(DOCKER_NAME) 2>/dev/null | grep "Build label:" >/dev/null; then break; fi; \ - if ! docker ps | grep $(DOCKER_NAME); then docker logs $(DOCKER_NAME); exit 1; else sleep 1; fi; done + $(BUILDER_IMAGE) \ + sh -c "tail -f --pid=\$$($(BAZEL) info server_pid)" .PHONY: bazel-server-start bazel-shutdown: ## Shuts down a running bazel server. - @docker exec --user $(UID):$(GID) $(DOCKER_NAME) $(BAZEL) shutdown; rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) shutdown; \ + rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] .PHONY: bazel-shutdown bazel-alias: ## Emits an alias that can be used within the shell. - @echo "alias bazel='docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel'" + @echo "alias bazel='docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) bazel'" .PHONY: bazel-alias bazel-server: ## Ensures that the server exists. Used as an internal target. - @docker exec $(DOCKER_NAME) true || $(MAKE) bazel-server-start + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) true || $(MAKE) bazel-server-start .PHONY: bazel-server -build_cmd = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) $(TARGETS)' +build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) $(TARGETS)' build_paths = $(build_cmd) 2>&1 \ | tee /proc/self/fd/2 \ @@ -126,9 +157,9 @@ sudo: bazel-server .PHONY: sudo test: bazel-server - @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) $(BAZEL) test $(OPTIONS) $(TARGETS) + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) test $(OPTIONS) $(TARGETS) .PHONY: test query: bazel-server - @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) $(BAZEL) query $(OPTIONS) '$(TARGETS)' + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) query $(OPTIONS) '$(TARGETS)' .PHONY: query diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD index 32a949c93..558826bf1 100644 --- a/tools/go_generics/BUILD +++ b/tools/go_generics/BUILD @@ -12,27 +12,3 @@ go_binary( visibility = ["//:sandbox"], deps = ["//tools/go_generics/globals"], ) - -genrule( - name = "go_generics_tests", - srcs = glob(["generics_tests/**"]) + [":go_generics"], - outs = ["go_generics_tests.tgz"], - cmd = "tar -czvhf $@ $(SRCS)", -) - -genrule( - name = "go_generics_test_bundle", - srcs = [ - ":go_generics_tests.tgz", - ":go_generics_unittest.sh", - ], - outs = ["go_generics_test.sh"], - cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@", - executable = True, -) - -sh_test( - name = "go_generics_test", - size = "small", - srcs = ["go_generics_test.sh"], -) diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl index ec047a644..33329cf28 100644 --- a/tools/go_generics/defs.bzl +++ b/tools/go_generics/defs.bzl @@ -100,20 +100,21 @@ def _go_template_instance_impl(ctx): # Build the argument list. args = ["-i=%s" % template.file.path, "-o=%s" % output.path] - args += ["-p=%s" % ctx.attr.package] + if ctx.attr.package: + args.append("-p=%s" % ctx.attr.package) if len(ctx.attr.prefix) > 0: - args += ["-prefix=%s" % ctx.attr.prefix] + args.append("-prefix=%s" % ctx.attr.prefix) if len(ctx.attr.suffix) > 0: - args += ["-suffix=%s" % ctx.attr.suffix] + args.append("-suffix=%s" % ctx.attr.suffix) args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()] args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()] args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()] if ctx.attr.anon: - args += ["-anon"] + args.append("-anon") ctx.actions.run( inputs = [template.file], @@ -151,7 +152,7 @@ go_template_instance = rule( "consts": attr.string_dict(), "imports": attr.string_dict(), "anon": attr.bool(mandatory = False, default = False), - "package": attr.string(mandatory = True), + "package": attr.string(mandatory = False), "out": attr.output(mandatory = True), "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")), }, diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt deleted file mode 100644 index c9d0e09bf..000000000 --- a/tools/go_generics/generics_tests/all_stmts/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt deleted file mode 100644 index c9d0e09bf..000000000 --- a/tools/go_generics/generics_tests/all_types/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q diff --git a/tools/go_generics/generics_tests/anon/opts.txt b/tools/go_generics/generics_tests/anon/opts.txt deleted file mode 100644 index a5e9d26de..000000000 --- a/tools/go_generics/generics_tests/anon/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q -suffix=New -anon diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt deleted file mode 100644 index 4fb59dce8..000000000 --- a/tools/go_generics/generics_tests/consts/opts.txt +++ /dev/null @@ -1 +0,0 @@ --c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC" diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt deleted file mode 100644 index 87324be79..000000000 --- a/tools/go_generics/generics_tests/imports/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt deleted file mode 100644 index 9c8ecaada..000000000 --- a/tools/go_generics/generics_tests/remove_typedef/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=U diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt deleted file mode 100644 index 7832ef66f..000000000 --- a/tools/go_generics/generics_tests/simple/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q -suffix=New diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh deleted file mode 100755 index 44b22db91..000000000 --- a/tools/go_generics/go_generics_unittest.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Bash "safe-mode": Treat command failures as fatal (even those that occur in -# pipes), and treat unset variables as errors. -set -eu -o pipefail - -# This file will be generated as a self-extracting shell script in order to -# eliminate the need for any runtime dependencies. The tarball at the end will -# include the go_generics binary, as well as a subdirectory named -# generics_tests. See the BUILD file for more information. -declare -r temp=$(mktemp -d) -function cleanup() { - rm -rf "${temp}" -} -# trap cleanup EXIT - -# Print message in "$1" then exit with status 1. -function die () { - echo "$1" 1>&2 - exit 1 -} - -# This prints the line number of __BUNDLE__ below, that should be the last line -# of this script. After that point, the concatenated archive will be the -# contents. -declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0` -tail -n+"${tgz}" $0 | tar -xzv -C "${temp}" - -# The target for the test. -declare -r binary="$(find ${temp} -type f -a -name go_generics)" -declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*" - -# Go through all test cases. -for f in ${input_dirs}; do - base=$(basename "${f}") - - # Run go_generics on the input file. - opts=$(head -n 1 ${f}/opts.txt) - out="${f}/output/generated.go" - expected="${f}/output/output.go" - ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\"" - - # Compare the outputs. - diff ${expected} ${out} - if [ $? -ne 0 ]; then - echo "Expected:" - cat ${expected} - echo "Actual:" - cat ${out} - die "Actual output is different from expected for test \"${base}\"" - fi -done - -echo "PASS" -exit 0 -__BUNDLE__ diff --git a/tools/go_generics/tests/BUILD b/tools/go_generics/tests/BUILD new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tools/go_generics/tests/BUILD diff --git a/tools/go_generics/tests/all_stmts/BUILD b/tools/go_generics/tests/all_stmts/BUILD new file mode 100644 index 000000000..a4a7c775a --- /dev/null +++ b/tools/go_generics/tests/all_stmts/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "all_stmts", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/tests/all_stmts/input.go index 4791d1ff1..4791d1ff1 100644 --- a/tools/go_generics/generics_tests/all_stmts/input.go +++ b/tools/go_generics/tests/all_stmts/input.go diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/tests/all_stmts/output.go index a53d84535..a53d84535 100644 --- a/tools/go_generics/generics_tests/all_stmts/output/output.go +++ b/tools/go_generics/tests/all_stmts/output.go diff --git a/tools/go_generics/tests/all_types/BUILD b/tools/go_generics/tests/all_types/BUILD new file mode 100644 index 000000000..60b1fd314 --- /dev/null +++ b/tools/go_generics/tests/all_types/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "all_types", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/tests/all_types/input.go index 3575d02ec..6f85bbb69 100644 --- a/tools/go_generics/generics_tests/all_types/input.go +++ b/tools/go_generics/tests/all_types/input.go @@ -14,7 +14,9 @@ package tests -import "./lib" +import ( + "./lib" +) type T int diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/tests/all_types/lib/lib.go index 988786496..988786496 100644 --- a/tools/go_generics/generics_tests/all_types/lib/lib.go +++ b/tools/go_generics/tests/all_types/lib/lib.go diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/tests/all_types/output.go index 41fd147a1..c0bbebfe7 100644 --- a/tools/go_generics/generics_tests/all_types/output/output.go +++ b/tools/go_generics/tests/all_types/output.go @@ -14,7 +14,9 @@ package main -import "./lib" +import ( + "./lib" +) type newType struct { a Q diff --git a/tools/go_generics/tests/anon/BUILD b/tools/go_generics/tests/anon/BUILD new file mode 100644 index 000000000..ef24f4b25 --- /dev/null +++ b/tools/go_generics/tests/anon/BUILD @@ -0,0 +1,18 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "anon", + anon = True, + inputs = ["input.go"], + output = "output.go", + suffix = "New", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/anon/input.go b/tools/go_generics/tests/anon/input.go index 44086d522..44086d522 100644 --- a/tools/go_generics/generics_tests/anon/input.go +++ b/tools/go_generics/tests/anon/input.go diff --git a/tools/go_generics/generics_tests/anon/output/output.go b/tools/go_generics/tests/anon/output.go index 160cddf79..7fa791853 100644 --- a/tools/go_generics/generics_tests/anon/output/output.go +++ b/tools/go_generics/tests/anon/output.go @@ -35,8 +35,8 @@ func (f FooNew) GetBar(name string) Q { func foobarNew() { a := BazNew{} - a.Q = 0 // should not be renamed, this is a limitation + a.Q = 0 b := otherpkg.UnrelatedType{} - b.Q = 0 // should not be renamed, this is a limitation + b.Q = 0 } diff --git a/tools/go_generics/tests/consts/BUILD b/tools/go_generics/tests/consts/BUILD new file mode 100644 index 000000000..fd7caccad --- /dev/null +++ b/tools/go_generics/tests/consts/BUILD @@ -0,0 +1,23 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "consts", + consts = { + "c1": "20", + "z": "600", + "v": "3.3", + "s": "\"def\"", + "A": "20", + "C": "100", + "S": "\"def\"", + "T": "\"ABC\"", + }, + inputs = ["input.go"], + output = "output.go", +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/tests/consts/input.go index 04b95fcc6..04b95fcc6 100644 --- a/tools/go_generics/generics_tests/consts/input.go +++ b/tools/go_generics/tests/consts/input.go diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/tests/consts/output.go index 18d316cc9..18d316cc9 100644 --- a/tools/go_generics/generics_tests/consts/output/output.go +++ b/tools/go_generics/tests/consts/output.go diff --git a/tools/go_generics/tests/defs.bzl b/tools/go_generics/tests/defs.bzl new file mode 100644 index 000000000..6277c3947 --- /dev/null +++ b/tools/go_generics/tests/defs.bzl @@ -0,0 +1,67 @@ +"""Generics tests.""" + +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") + +def _go_generics_test_impl(ctx): + runner = ctx.actions.declare_file(ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "exec diff --ignore-blank-lines --ignore-matching-lines=^[[:space:]]*// %s %s" % ( + ctx.files.template_output[0].short_path, + ctx.files.expected_output[0].short_path, + ), + "", + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + return [DefaultInfo( + executable = runner, + runfiles = ctx.runfiles( + files = ctx.files.template_output + ctx.files.expected_output, + collect_default = True, + collect_data = True, + ), + )] + +_go_generics_test = rule( + implementation = _go_generics_test_impl, + attrs = { + "template_output": attr.label(mandatory = True, allow_single_file = True), + "expected_output": attr.label(mandatory = True, allow_single_file = True), + }, + test = True, +) + +def go_generics_test(name, inputs, output, types = None, consts = None, **kwargs): + """Instantiates a generics test. + + Args: + name: the name of the test. + inputs: all the input files. + output: the output files. + types: the template types (dictionary). + consts: the template consts (dictionary). + **kwargs: additional arguments for the template_instance. + """ + if types == None: + types = dict() + if consts == None: + consts = dict() + go_template( + name = name + "_template", + srcs = inputs, + types = types.keys(), + consts = consts.keys(), + ) + go_template_instance( + name = name + "_output", + template = ":" + name + "_template", + out = name + "_output.go", + types = types, + consts = consts, + **kwargs + ) + _go_generics_test( + name = name + "_test", + template_output = name + "_output.go", + expected_output = output, + ) diff --git a/tools/go_generics/tests/imports/BUILD b/tools/go_generics/tests/imports/BUILD new file mode 100644 index 000000000..a86223d41 --- /dev/null +++ b/tools/go_generics/tests/imports/BUILD @@ -0,0 +1,24 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "imports", + consts = { + "n": "math.Uint32", + "m": "math.Uint64", + }, + imports = { + "sync": "sync", + "math": "mymathpath", + }, + inputs = ["input.go"], + output = "output.go", + types = { + "T": "sync.Mutex", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/tests/imports/input.go index 0f032c2a1..0f032c2a1 100644 --- a/tools/go_generics/generics_tests/imports/input.go +++ b/tools/go_generics/tests/imports/input.go diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/tests/imports/output.go index 2488ca58c..2488ca58c 100644 --- a/tools/go_generics/generics_tests/imports/output/output.go +++ b/tools/go_generics/tests/imports/output.go diff --git a/tools/go_generics/tests/remove_typedef/BUILD b/tools/go_generics/tests/remove_typedef/BUILD new file mode 100644 index 000000000..46457cec6 --- /dev/null +++ b/tools/go_generics/tests/remove_typedef/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "remove_typedef", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "U", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/tests/remove_typedef/input.go index cf632bae7..cf632bae7 100644 --- a/tools/go_generics/generics_tests/remove_typedef/input.go +++ b/tools/go_generics/tests/remove_typedef/input.go diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/tests/remove_typedef/output.go index d44fd8e1c..d44fd8e1c 100644 --- a/tools/go_generics/generics_tests/remove_typedef/output/output.go +++ b/tools/go_generics/tests/remove_typedef/output.go diff --git a/tools/go_generics/tests/simple/BUILD b/tools/go_generics/tests/simple/BUILD new file mode 100644 index 000000000..4b9265ea4 --- /dev/null +++ b/tools/go_generics/tests/simple/BUILD @@ -0,0 +1,17 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "simple", + inputs = ["input.go"], + output = "output.go", + suffix = "New", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/tests/simple/input.go index 2a917f16c..2a917f16c 100644 --- a/tools/go_generics/generics_tests/simple/input.go +++ b/tools/go_generics/tests/simple/input.go diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/tests/simple/output.go index 6bfa0b25b..6bfa0b25b 100644 --- a/tools/go_generics/generics_tests/simple/output/output.go +++ b/tools/go_generics/tests/simple/output.go |