diff options
91 files changed, 2353 insertions, 1266 deletions
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 006b5a525..29062c97a 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -170,3 +170,18 @@ const ( KCOV_MODE_TRACE_PC = 2 KCOV_MODE_TRACE_CMP = 3 ) + +// Attestation ioctls. +var ( + SIGN_ATTESTATION_REPORT = IOC(_IOC_READ, 's', 1, 65) +) + +// SizeOfQuoteInputData is the number of bytes in the input data of ioctl call +// to get quote. +const SizeOfQuoteInputData = 64 + +// SignReport is a struct that gets signed quote from input data. +type SignReport struct { + data [64]byte + quote []byte +} diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index 69eeb7528..4d5e062a8 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -37,6 +37,14 @@ package cpuid // arch/arm64/include/uapi/asm/hwcap.h type Feature int +// HostFeatureSet returns a FeatureSet that matches that of the host machine. +// Callers must not mutate the returned FeatureSet. +func HostFeatureSet() *FeatureSet { + return hostFeatureSet +} + +var hostFeatureSet = getHostFeatureSet() + // ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a // subset of the host feature set. type ErrIncompatible struct { diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go index 6e61d562f..04645aed5 100644 --- a/pkg/cpuid/cpuid_arm64.go +++ b/pkg/cpuid/cpuid_arm64.go @@ -230,6 +230,16 @@ type FeatureSet struct { CPURevision uint8 } +// Clone returns a copy of fs. +func (fs *FeatureSet) Clone() *FeatureSet { + fs2 := *fs + fs2.Set = make(map[Feature]bool) + for f, b := range fs.Set { + fs2.Set[f] = b + } + return &fs2 +} + // CheckHostCompatible returns nil if fs is a subset of the host feature set. // Noop on arm64. func (fs *FeatureSet) CheckHostCompatible() error { @@ -292,9 +302,9 @@ func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) { fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline. } -// HostFeatureSet uses hwCap to get host values and construct a feature set +// getHostFeatureSet uses hwCap to get host values and construct a feature set // that matches that of the host machine. -func HostFeatureSet() *FeatureSet { +func getHostFeatureSet() *FeatureSet { s := make(map[Feature]bool) for f := range arm64FeatureStrings { diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go index dc17cade8..a92d32d74 100644 --- a/pkg/cpuid/cpuid_x86.go +++ b/pkg/cpuid/cpuid_x86.go @@ -627,6 +627,17 @@ type FeatureSet struct { CacheLine uint32 } +// Clone returns a copy of fs. +func (fs *FeatureSet) Clone() *FeatureSet { + fs2 := *fs + fs2.Set = make(map[Feature]bool) + for f, b := range fs.Set { + fs2.Set[f] = b + } + fs2.Caches = append([]Cache(nil), fs.Caches...) + return &fs2 +} + // FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is // equivalent to the "flags" field in /proc/cpuinfo. func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string { @@ -961,13 +972,13 @@ func (fs *FeatureSet) UseXsaveopt() bool { // HostID executes a native CPUID instruction. func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32) -// HostFeatureSet uses cpuid to get host values and construct a feature set +// getHostFeatureSet uses cpuid to get host values and construct a feature set // that matches that of the host machine. Note that there are several places // where there appear to be some unnecessary assignments between register names // (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show // where the different feature blocks come from, to make the code easier to // inspect and read. -func HostFeatureSet() *FeatureSet { +func getHostFeatureSet() *FeatureSet { // eax=0 gets max supported feature and vendor ID. _, bx, cx, dx := HostID(0, 0) vendorID := vendorIDFromRegs(bx, cx, dx) diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD index 6cdd21b1b..1f371e513 100644 --- a/pkg/sentry/arch/fpu/BUILD +++ b/pkg/sentry/arch/fpu/BUILD @@ -9,6 +9,7 @@ go_library( "fpu_amd64.go", "fpu_amd64.s", "fpu_arm64.go", + "fpu_unsafe.go", ], visibility = ["//:sandbox"], deps = [ diff --git a/pkg/sentry/arch/fpu/fpu.go b/pkg/sentry/arch/fpu/fpu.go index 867d309a3..62bde19d3 100644 --- a/pkg/sentry/arch/fpu/fpu.go +++ b/pkg/sentry/arch/fpu/fpu.go @@ -17,7 +17,6 @@ package fpu import ( "fmt" - "reflect" ) // State represents floating point state. @@ -40,15 +39,3 @@ type ErrLoadingState struct { func (e ErrLoadingState) Error() string { return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supportedFeatures, e.savedFeatures) } - -// alignedBytes returns a slice of size bytes, aligned in memory to the given -// alignment. This is used because we require certain structures to be aligned -// in a specific way (for example, the X86 floating point data). -func alignedBytes(size, alignment uint) []byte { - data := make([]byte, size+alignment-1) - offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment)) - if offset == 0 { - return data[:size:size] - } - return data[alignment-offset:][:size:size] -} diff --git a/pkg/sentry/arch/fpu/fpu_unsafe.go b/pkg/sentry/arch/fpu/fpu_unsafe.go new file mode 100644 index 000000000..c91dc99be --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_unsafe.go @@ -0,0 +1,31 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fpu + +import ( + "unsafe" +) + +// alignedBytes returns a slice of size bytes, aligned in memory to the given +// alignment. This is used because we require certain structures to be aligned +// in a specific way (for example, the X86 floating point data). +func alignedBytes(size, alignment uint) []byte { + data := make([]byte, size+alignment-1) + offset := uint(uintptr(unsafe.Pointer(&data[0])) % uintptr(alignment)) + if offset == 0 { + return data[:size:size] + } + return data[alignment-offset:][:size:size] +} diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go index 379429ab2..75dc5d204 100644 --- a/pkg/sentry/fs/proc/exec_args.go +++ b/pkg/sentry/fs/proc/exec_args.go @@ -107,7 +107,7 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen return 0, linuxerr.EINVAL } - m, err := getTaskMM(f.t) + m, err := getTaskMMIncRef(f.t) if err != nil { return 0, err } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 89a799b21..03f2a882d 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -41,10 +41,24 @@ import ( // LINT.IfChange -// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's -// users count is incremented, and must be decremented by the caller when it is -// no longer in use. -func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) { +// getTaskMM gets the kernel task's MemoryManager. No additional reference is +// taken on mm here. This is safe because MemoryManager.destroy is required to +// leave the MemoryManager in a state where it's still usable as a +// DynamicBytesSource. +func getTaskMM(t *kernel.Task) *mm.MemoryManager { + var tmm *mm.MemoryManager + t.WithMuLocked(func(t *kernel.Task) { + if mm := t.MemoryManager(); mm != nil { + tmm = mm + } + }) + return tmm +} + +// getTaskMMIncRef returns t's MemoryManager. If getTaskMMIncRef succeeds, the +// MemoryManager's users count is incremented, and must be decremented by the +// caller when it is no longer in use. +func getTaskMMIncRef(t *kernel.Task) (*mm.MemoryManager, error) { if t.ExitState() == kernel.TaskExitDead { return nil, linuxerr.ESRCH } @@ -269,21 +283,18 @@ func (e *exe) executable() (file fsbridge.File, err error) { if err := checkTaskState(e.t); err != nil { return nil, err } - e.t.WithMuLocked(func(t *kernel.Task) { - mm := t.MemoryManager() - if mm == nil { - err = linuxerr.EACCES - return - } + mm := getTaskMM(e.t) + if mm == nil { + return nil, linuxerr.EACCES + } - // The MemoryManager may be destroyed, in which case - // MemoryManager.destroy will simply set the executable to nil - // (with locks held). - file = mm.Executable() - if file == nil { - err = linuxerr.ESRCH - } - }) + // The MemoryManager may be destroyed, in which case + // MemoryManager.destroy will simply set the executable to nil + // (with locks held). + file = mm.Executable() + if file == nil { + err = linuxerr.ESRCH + } return } @@ -463,7 +474,7 @@ func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen if dst.NumBytes() == 0 { return 0, nil } - mm, err := getTaskMM(m.t) + mm, err := getTaskMMIncRef(m.t) if err != nil { return 0, nil } @@ -494,22 +505,9 @@ func newMaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inod return newProcInode(ctx, seqfile.NewSeqFile(ctx, &mapsData{t}), msrc, fs.SpecialFile, t) } -func (md *mapsData) mm() *mm.MemoryManager { - var tmm *mm.MemoryManager - md.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - // No additional reference is taken on mm here. This is safe - // because MemoryManager.destroy is required to leave the - // MemoryManager in a state where it's still usable as a SeqSource. - tmm = mm - } - }) - return tmm -} - // NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. func (md *mapsData) NeedsUpdate(generation int64) bool { - if mm := md.mm(); mm != nil { + if mm := getTaskMM(md.t); mm != nil { return mm.NeedsUpdate(generation) } return true @@ -517,7 +515,7 @@ func (md *mapsData) NeedsUpdate(generation int64) bool { // ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. func (md *mapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { - if mm := md.mm(); mm != nil { + if mm := getTaskMM(md.t); mm != nil { return mm.ReadMapsSeqFileData(ctx, h) } return []seqfile.SeqData{}, 0 @@ -534,22 +532,9 @@ func newSmaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Ino return newProcInode(ctx, seqfile.NewSeqFile(ctx, &smapsData{t}), msrc, fs.SpecialFile, t) } -func (sd *smapsData) mm() *mm.MemoryManager { - var tmm *mm.MemoryManager - sd.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - // No additional reference is taken on mm here. This is safe - // because MemoryManager.destroy is required to leave the - // MemoryManager in a state where it's still usable as a SeqSource. - tmm = mm - } - }) - return tmm -} - // NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. func (sd *smapsData) NeedsUpdate(generation int64) bool { - if mm := sd.mm(); mm != nil { + if mm := getTaskMM(sd.t); mm != nil { return mm.NeedsUpdate(generation) } return true @@ -557,7 +542,7 @@ func (sd *smapsData) NeedsUpdate(generation int64) bool { // ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. func (sd *smapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { - if mm := sd.mm(); mm != nil { + if mm := getTaskMM(sd.t); mm != nil { return mm.ReadSmapsSeqFileData(ctx, h) } return []seqfile.SeqData{}, 0 @@ -627,12 +612,10 @@ func (s *taskStatData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) fmt.Fprintf(&buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime()))) var vss, rss uint64 - s.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) + if mm := getTaskMM(s.t); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } fmt.Fprintf(&buf, "%d %d ", vss, rss/hostarch.PageSize) // rsslim. @@ -677,12 +660,10 @@ func (s *statmData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([ } var vss, rss uint64 - s.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) + if mm := getTaskMM(s.t); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } var buf bytes.Buffer fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize) @@ -734,12 +715,13 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( if fdTable := t.FDTable(); fdTable != nil { fds = fdTable.CurrentMaxFDs() } - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - data = mm.VirtualDataSize() - } }) + + if mm := getTaskMM(s.t); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + data = mm.VirtualDataSize() + } fmt.Fprintf(&buf, "FDSize:\t%d\n", fds) fmt.Fprintf(&buf, "VmSize:\t%d kB\n", vss>>10) fmt.Fprintf(&buf, "VmRSS:\t%d kB\n", rss>>10) @@ -925,7 +907,7 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc return 0, linuxerr.EINVAL } - m, err := getTaskMM(f.t) + m, err := getTaskMMIncRef(f.t) if err != nil { return 0, err } diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index bd6b30397..43440ec19 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -995,7 +995,7 @@ func (d *dentry) refreshSizeLocked(ctx context.Context) error { if d.writeFD < 0 { d.handleMu.RUnlock() // Ask the gofer if we don't have a host FD. - return d.updateFromGetattrLocked(ctx) + return d.updateFromGetattrLocked(ctx, p9file{}) } var stat unix.Statx_t @@ -1014,33 +1014,35 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { // updating stale attributes in d.updateFromP9AttrsLocked(). d.metadataMu.Lock() defer d.metadataMu.Unlock() - return d.updateFromGetattrLocked(ctx) + return d.updateFromGetattrLocked(ctx, p9file{}) } // Preconditions: // * !d.isSynthetic(). // * d.metadataMu is locked. // +checklocks:d.metadataMu -func (d *dentry) updateFromGetattrLocked(ctx context.Context) error { - // Use d.readFile or d.writeFile, which represent 9P FIDs that have been - // opened, in preference to d.file, which represents a 9P fid that has not. - // This may be significantly more efficient in some implementations. Prefer - // d.writeFile over d.readFile since some filesystem implementations may - // update a writable handle's metadata after writes to that handle, without - // making metadata updates immediately visible to read-only handles - // representing the same file. - d.handleMu.RLock() - handleMuRLocked := true - var file p9file - switch { - case !d.writeFile.isNil(): - file = d.writeFile - case !d.readFile.isNil(): - file = d.readFile - default: - file = d.file - d.handleMu.RUnlock() - handleMuRLocked = false +func (d *dentry) updateFromGetattrLocked(ctx context.Context, file p9file) error { + handleMuRLocked := false + if file.isNil() { + // Use d.readFile or d.writeFile, which represent 9P FIDs that have + // been opened, in preference to d.file, which represents a 9P fid that + // has not. This may be significantly more efficient in some + // implementations. Prefer d.writeFile over d.readFile since some + // filesystem implementations may update a writable handle's metadata + // after writes to that handle, without making metadata updates + // immediately visible to read-only handles representing the same file. + d.handleMu.RLock() + switch { + case !d.writeFile.isNil(): + file = d.writeFile + handleMuRLocked = true + case !d.readFile.isNil(): + file = d.readFile + handleMuRLocked = true + default: + file = d.file + d.handleMu.RUnlock() + } } _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask()) @@ -2044,9 +2046,17 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu d := fd.dentry() const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME) if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC { - // TODO(jamieliu): Use specialFileFD.handle.file for the getattr if - // available? - if err := d.updateFromGetattr(ctx); err != nil { + // Use specialFileFD.handle.file for the getattr if available, for the + // same reason that we try to use open file handles in + // dentry.updateFromGetattrLocked(). + var file p9file + if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok { + file = sffd.handle.file + } + d.metadataMu.Lock() + err := d.updateFromGetattrLocked(ctx, file) + d.metadataMu.Unlock() + if err != nil { return linux.Statx{}, err } } diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 618092ef1..520487066 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -36,6 +36,10 @@ func (d *dentry) isCopiedUp() bool { // // Preconditions: filesystem.renameMu must be locked. func (d *dentry) copyUpLocked(ctx context.Context) error { + return d.copyUpMaybeSyntheticMountpointLocked(ctx, false /* forSyntheticMountpoint */) +} + +func (d *dentry) copyUpMaybeSyntheticMountpointLocked(ctx context.Context, forSyntheticMountpoint bool) error { // Fast path. if d.isCopiedUp() { return nil @@ -59,7 +63,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { // d is a filesystem root with no upper layer. return linuxerr.EROFS } - if err := d.parent.copyUpLocked(ctx); err != nil { + if err := d.parent.copyUpMaybeSyntheticMountpointLocked(ctx, forSyntheticMountpoint); err != nil { return err } @@ -168,7 +172,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { case linux.S_IFDIR: if err := vfsObj.MkdirAt(ctx, d.fs.creds, &newpop, &vfs.MkdirOptions{ - Mode: linux.FileMode(d.mode &^ linux.S_IFMT), + Mode: linux.FileMode(d.mode &^ linux.S_IFMT), + ForSyntheticMountpoint: forSyntheticMountpoint, }); err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index c04c80590..3b3dcf836 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -462,13 +462,21 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, return d, nil } +type createType int + +const ( + createNonDirectory createType = iota + createDirectory + createSyntheticMountpoint +) + // doCreateAt checks that creating a file at rp is permitted, then invokes // create to do so. // // Preconditions: // * !rp.Done(). // * For the final path component in rp, !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error { +func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, ct createType, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) @@ -504,7 +512,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return linuxerr.EEXIST } - if !dir && rp.MustBeDir() { + if ct == createNonDirectory && rp.MustBeDir() { return linuxerr.ENOENT } @@ -518,7 +526,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir } // Ensure that the parent directory is copied-up so that we can create the // new file in the upper layer. - if err := parent.copyUpLocked(ctx); err != nil { + if err := parent.copyUpMaybeSyntheticMountpointLocked(ctx, ct == createSyntheticMountpoint); err != nil { return err } @@ -529,7 +537,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir parent.dirents = nil ev := linux.IN_CREATE - if dir { + if ct != createNonDirectory { ev |= linux.IN_ISDIR } parent.watches.Notify(ctx, name, uint32(ev), 0 /* cookie */, vfs.InodeEvent, false /* unlinked */) @@ -618,7 +626,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error { + return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error { if rp.Mount() != vd.Mount() { return linuxerr.EXDEV } @@ -671,7 +679,11 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error { + ct := createDirectory + if opts.ForSyntheticMountpoint { + ct = createSyntheticMountpoint + } + return fs.doCreateAt(ctx, rp, ct, func(parent *dentry, childName string, haveUpperWhiteout bool) error { vfsObj := fs.vfsfs.VirtualFilesystem() pop := vfs.PathOperation{ Root: parent.upperVD, @@ -722,7 +734,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error { + return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error { // Disallow attempts to create whiteouts. if opts.Mode&linux.S_IFMT == linux.S_IFCHR && opts.DevMajor == 0 && opts.DevMinor == 0 { return linuxerr.EPERM @@ -1476,7 +1488,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error { + return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error { vfsObj := fs.vfsfs.VirtualFilesystem() pop := vfs.PathOperation{ Root: parent.upperVD, diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 34b0c4f63..d3f9cf489 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -40,7 +40,7 @@ import ( // Linux 3.18, the limit is five lines." - user_namespaces(7) const maxIDMapLines = 5 -// mm gets the kernel task's MemoryManager. No additional reference is taken on +// getMM gets the kernel task's MemoryManager. No additional reference is taken on // mm here. This is safe because MemoryManager.destroy is required to leave the // MemoryManager in a state where it's still usable as a DynamicBytesSource. func getMM(task *kernel.Task) *mm.MemoryManager { @@ -608,12 +608,10 @@ func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime()))) var vss, rss uint64 - s.task.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) + if mm := getMM(s.task); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } fmt.Fprintf(buf, "%d %d ", vss, rss/hostarch.PageSize) // rsslim. @@ -649,13 +647,10 @@ var _ dynamicInode = (*statmData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { var vss, rss uint64 - s.task.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) - + if mm := getMM(s.task); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize) return nil } @@ -779,12 +774,12 @@ func (s *statusFD) Generate(ctx context.Context, buf *bytes.Buffer) error { if fdTable := t.FDTable(); fdTable != nil { fds = fdTable.CurrentMaxFDs() } - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - data = mm.VirtualDataSize() - } }) + if mm := getMM(s.task); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + data = mm.VirtualDataSize() + } // Filesystem user/group IDs aren't implemented; effective UID/GID are used // instead. fmt.Fprintf(buf, "Uid:\t%d\t%d\t%d\t%d\n", ruid, euid, suid, euid) @@ -945,25 +940,17 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent return vfs.VirtualDentry{}, "", err } - var err error - var exec fsbridge.File - s.task.WithMuLocked(func(t *kernel.Task) { - mm := t.MemoryManager() - if mm == nil { - err = linuxerr.EACCES - return - } + mm := getMM(s.task) + if mm == nil { + return vfs.VirtualDentry{}, "", linuxerr.EACCES + } - // The MemoryManager may be destroyed, in which case - // MemoryManager.destroy will simply set the executable to nil - // (with locks held). - exec = mm.Executable() - if exec == nil { - err = linuxerr.ESRCH - } - }) - if err != nil { - return vfs.VirtualDentry{}, "", err + // The MemoryManager may be destroyed, in which case + // MemoryManager.destroy will simply set the executable to nil + // (with locks held). + exec := mm.Executable() + if exec == nil { + return vfs.VirtualDentry{}, "", linuxerr.ESRCH } defer exec.DecRef(ctx) diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 9175b911c..db91fc4d8 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -222,9 +222,15 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { // Update credentials to reflect the execve. This should precede switching // MMs to ensure that dumpability has been reset first, if needed. t.updateCredsForExecLocked() - t.image.release() + oldImage := t.image t.image = *r.image t.mu.Unlock() + + // Don't hold t.mu while calling t.image.release(), that may + // attempt to acquire TaskImage.MemoryManager.mappingMu, a lock order + // violation. + oldImage.release() + t.unstopVforkParent() t.p.FullStateChanged() // NOTE(b/30316266): All locks must be dropped prior to calling Activate. diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index 342e5debe..b3931445b 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -230,9 +230,16 @@ func (*runExitMain) execute(t *Task) taskRunState { t.tg.pidns.owner.mu.Lock() t.updateRSSLocked() t.tg.pidns.owner.mu.Unlock() + + // Release the task image resources. Accessing these fields must be + // done with t.mu held, but the mm.DecUsers() call must be done outside + // of that lock. t.mu.Lock() - t.image.release() + mm := t.image.MemoryManager + t.image.MemoryManager = nil + t.image.fu = nil t.mu.Unlock() + mm.DecUsers(t) // Releasing the MM unblocks a blocked CLONE_VFORK parent. t.unstopVforkParent() diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go index c132c27ef..6002ffb42 100644 --- a/pkg/sentry/kernel/task_image.go +++ b/pkg/sentry/kernel/task_image.go @@ -53,7 +53,7 @@ type TaskImage struct { } // release releases all resources held by the TaskImage. release is called by -// the task when it execs into a new TaskImage or exits. +// the task when it execs into a new TaskImage. func (image *TaskImage) release() { // Nil out pointers so that if the task is saved after release, it doesn't // follow the pointers to possibly now-invalid objects. diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 6a356779c..2759ef71e 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -295,15 +295,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V m.SetEnvvEnd(sl.EnvvEnd) m.SetAuxv(auxv) m.SetExecutable(ctx, file) - - symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn") - if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to find rt_sigreturn in vdso: %v", err), syserr.FromError(err).ToLinux()) - } - - // Found rt_sigretrun. - addr := uint64(vdsoAddr) + symbolValue - vdsoPrelink - m.SetVDSOSigReturn(addr) + m.SetVDSOSigReturn(uint64(vdsoAddr) + vdsoSigreturnOffset - vdsoPrelink) ac.SetIP(uintptr(loaded.entry)) ac.SetStack(uintptr(stack.Bottom)) diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index 3abd2ee7d..bcee6aef6 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -19,7 +19,6 @@ import ( "debug/elf" "fmt" "io" - "strings" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/context" @@ -177,27 +176,6 @@ type VDSO struct { phdrs []elf.ProgHeader `state:".([]elfProgHeader)"` } -// getSymbolValueFromVDSO returns the specific symbol value in vdso.so. -func getSymbolValueFromVDSO(symbol string) (uint64, error) { - f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary)) - if err != nil { - return 0, err - } - syms, err := f.Symbols() - if err != nil { - return 0, err - } - - for _, sym := range syms { - if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF { - if strings.Contains(sym.Name, symbol) { - return sym.Value, nil - } - } - } - return 0, fmt.Errorf("no %v in vdso.so", symbol) -} - // PrepareVDSO validates the system VDSO and returns a VDSO, containing the // param page for updating by the kernel. func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { @@ -388,3 +366,21 @@ func (v *VDSO) Release(ctx context.Context) { v.ParamPage.DecRef(ctx) v.vdso.DecRef(ctx) } + +var vdsoSigreturnOffset = func() uint64 { + f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary)) + if err != nil { + panic(fmt.Sprintf("failed to parse vdso.so as ELF file: %v", err)) + } + syms, err := f.Symbols() + if err != nil { + panic(fmt.Sprintf("failed to read symbols from vdso.so: %v", err)) + } + const sigreturnSymbol = "__kernel_rt_sigreturn" + for _, sym := range syms { + if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF && sym.Name == sigreturnSymbol { + return sym.Value + } + } + panic(fmt.Sprintf("no symbol %q in vdso.so", sigreturnSymbol)) +}() diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 57969b26c..0fca59b64 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -28,6 +28,7 @@ // memmap.File locks // mm.aioManager.mu // mm.AIOContext.mu +// kernel.TaskSet.mu // // Only mm.MemoryManager.Fork is permitted to lock mm.MemoryManager.activeMu in // multiple mm.MemoryManagers, as it does so in a well-defined order (forked diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 9f4cc238f..05cdcd8ae 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -324,20 +324,37 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter panic(fmt.Sprintf("pma %v needs to be copied for writing, but is not readable: %v", pseg.Range(), oldpma)) } } - // The majority of copy-on-write breaks on executable pages - // come from: - // - // - The ELF loader, which must zero out bytes on the last - // page of each segment after the end of the segment. - // - // - gdb's use of ptrace to insert breakpoints. - // - // Neither of these cases has enough spatial locality to - // benefit from copying nearby pages, so if the vma is - // executable, only copy the pages required. var copyAR hostarch.AddrRange - if vseg.ValuePtr().effectivePerms.Execute { + if vma := vseg.ValuePtr(); vma.effectivePerms.Execute { + // The majority of copy-on-write breaks on executable + // pages come from: + // + // - The ELF loader, which must zero out bytes on the + // last page of each segment after the end of the + // segment. + // + // - gdb's use of ptrace to insert breakpoints. + // + // Neither of these cases has enough spatial locality + // to benefit from copying nearby pages, so if the vma + // is executable, only copy the pages required. copyAR = pseg.Range().Intersect(ar) + } else if vma.growsDown { + // In most cases, the new process will not use most of + // its stack before exiting or invoking execve(); it is + // especially unlikely to return very far down its call + // stack, since async-signal-safety concerns in + // multithreaded programs prevent the new process from + // being able to do much. So only copy up to one page + // before and after the pages required. + stackMaskAR := ar + if newStart := stackMaskAR.Start - hostarch.PageSize; newStart < stackMaskAR.Start { + stackMaskAR.Start = newStart + } + if newEnd := stackMaskAR.End + hostarch.PageSize; newEnd > stackMaskAR.End { + stackMaskAR.End = newEnd + } + copyAR = pseg.Range().Intersect(stackMaskAR) } else { copyAR = pseg.Range().Intersect(maskAR) } diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e347442e7..bf5ec4558 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/usermem", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 2f9462cee..8cf2f29e4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -59,8 +59,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -2045,7 +2045,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { return syserr.ErrInvalidEndpointState - } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + } else if isUDPSocket(skType, skProto) && transport.DatagramEndpointState(ep.State()) != transport.DatagramEndpointStateInitial { return syserr.ErrInvalidEndpointState } @@ -3331,10 +3331,10 @@ func (s *socketOpsCommon) State() uint32 { } case isUDPSocket(s.skType, s.protocol): // UDP socket. - switch udp.EndpointState(s.Endpoint.State()) { - case udp.StateInitial, udp.StateBound, udp.StateClosed: + switch transport.DatagramEndpointState(s.Endpoint.State()) { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateBound, transport.DatagramEndpointStateClosed: return linux.TCP_CLOSE - case udp.StateConnected: + case transport.DatagramEndpointStateConnected: return linux.TCP_ESTABLISHED default: return 0 diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 658e90bb9..83b9d9389 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -750,6 +750,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: Ntohs(a.Protocol), }, family, nil case linux.AF_UNSPEC: diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e4de44498..a9cedcf5f 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -133,7 +133,7 @@ func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, f free := q.limit - q.used if l > free && truncate { - if free == 0 { + if free <= 0 { // Message can't fit right now. q.mu.Unlock() return 0, false, syserr.ErrWouldBlock diff --git a/pkg/shim/runtimeoptions/runtimeoptions.go b/pkg/shim/runtimeoptions/runtimeoptions.go index 072dd87f0..e76d73ea7 100644 --- a/pkg/shim/runtimeoptions/runtimeoptions.go +++ b/pkg/shim/runtimeoptions/runtimeoptions.go @@ -15,3 +15,10 @@ // Package runtimeoptions contains the runtimeoptions proto. package runtimeoptions + +import proto "github.com/gogo/protobuf/proto" + +func init() { + // TODO(gvisor.dev/issue/6449): Upgrade runtimeoptions.proto after upgrading to containerd 1.5 + proto.RegisterType((*Options)(nil), "runtimeoptions.v1.Options") +} diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD index 1cd5d641d..d8c4c9613 100644 --- a/pkg/syserr/BUILD +++ b/pkg/syserr/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/abi/linux/errno", "//pkg/errors", "//pkg/errors/linuxerr", + "//pkg/safecopy", "//pkg/tcpip", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go index a5e386e38..b679f3046 100644 --- a/pkg/syserr/syserr.go +++ b/pkg/syserr/syserr.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/errors" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/safecopy" ) // Error represents an internal error. @@ -278,15 +279,18 @@ func FromError(err error) *Error { if err == nil { return nil } - if errno, ok := err.(unix.Errno); ok { - return FromHost(errno) - } - if linuxErr, ok := err.(*errors.Error); ok { - return FromHost(unix.Errno(linuxErr.Errno())) + switch e := err.(type) { + case unix.Errno: + return FromHost(e) + case *errors.Error: + return FromHost(unix.Errno(e.Errno())) + case safecopy.SegvError, safecopy.BusError, safecopy.AlignmentError: + return FromHost(unix.EFAULT) } - panic("unknown error: " + err.Error()) + msg := fmt.Sprintf("err: %s type: %T", err.Error(), err) + panic(msg) } // ConvertIntr converts the provided error code (err) to another one (intr) if diff --git a/pkg/tcpip/internal/tcp/BUILD b/pkg/tcpip/internal/tcp/BUILD new file mode 100644 index 000000000..9ae258a0b --- /dev/null +++ b/pkg/tcpip/internal/tcp/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "tcp", + srcs = ["tcp.go"], + visibility = ["//pkg/tcpip:__subpackages__"], + deps = [ + "//pkg/tcpip", + ], +) diff --git a/pkg/tcpip/internal/tcp/tcp.go b/pkg/tcpip/internal/tcp/tcp.go new file mode 100644 index 000000000..0616d368c --- /dev/null +++ b/pkg/tcpip/internal/tcp/tcp.go @@ -0,0 +1,48 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tcp contains internal type definitions that are not expected to be +// used by anyone else outside pkg/tcpip. +package tcp + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// TSOffset is an offset applied to the value of the TSVal field in the TCP +// Timestamp option. +// +// +stateify savable +type TSOffset struct { + milliseconds uint32 +} + +// NewTSOffset creates a new TSOffset from milliseconds. +func NewTSOffset(milliseconds uint32) TSOffset { + return TSOffset{ + milliseconds: milliseconds, + } +} + +// TSVal applies the offset to now and returns the timestamp in milliseconds. +func (offset TSOffset) TSVal(now tcpip.MonotonicTime) uint32 { + return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + offset.milliseconds +} + +// Elapsed calculates the elapsed time given now and the echoed back timestamp. +func (offset TSOffset) Elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration { + return time.Duration(offset.TSVal(now)-tsEcr) * time.Millisecond +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index f26c857eb..d02eea93c 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -290,3 +290,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { // AddHeader implements stack.LinkEndpoint.AddHeader. func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } + +// WriteRawPacket implements stack.LinkEndpoint. +func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index b9db273d0..8211a2031 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -112,3 +112,8 @@ func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkP } eth.Encode(&fields) } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + return e.Endpoint.WriteRawPacket(pkt) +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 48356c343..058242f96 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -505,6 +505,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net } } +// WriteRawPacket implements stack.LinkEndpoint. +func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } + // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 7012d8829..ca1f9c08d 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,19 +76,8 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - // Construct data as the unparsed portion for the loopback packet. - data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - - // Because we're immediately turning around and writing the packet back - // to the rx path, we intentionally don't preserve the remote and local - // link addresses from the stack.Route we're passed. - newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: data, - }) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt) - - return nil +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + return e.WriteRawPacket(pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -103,3 +92,19 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + // Construct data as the unparsed portion for the loopback packet. + data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + + // Because we're immediately turning around and writing the packet back + // to the rx path, we intentionally don't preserve the remote and local + // link addresses from the stack.Route we're passed. + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, + }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, pkt.NetworkProtocolNumber, newPkt) + + return nil +} diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 3e2a1aa94..844f5959b 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -131,6 +131,11 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { } +// WriteRawPacket implements stack.LinkEndpoint. +func (*InjectableEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 3e816b0c7..14cb96d63 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -152,3 +152,8 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.child.AddHeader(local, remote, protocol, pkt) } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + return e.child.WriteRawPacket(pkt) +} diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 5030b6ba1..3ed0aa3fe 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -121,3 +121,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { // AddHeader implements stack.LinkEndpoint. func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { } + +// WriteRawPacket implements stack.LinkEndpoint. +func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 40bd5560b..dc63e5fb0 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -228,3 +228,8 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.lower.AddHeader(local, remote, protocol, pkt) } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + return e.lower.WriteRawPacket(pkt) +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 30cf659b8..66efe6472 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -202,6 +202,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net eth.Encode(ethHdr) } +// WriteRawPacket implements stack.LinkEndpoint. +func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } + // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index a95602aa5..13900205d 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -155,3 +155,6 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.lower.AddHeader(local, remote, protocol, pkt) } + +// WriteRawPacket implements stack.LinkEndpoint. +func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index a71400ee9..b0e4237bd 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -80,6 +80,11 @@ func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBuffe return pkts.Len(), nil } +// WriteRawPacket implements stack.LinkEndpoint. +func (*countedEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { panic("unimplemented") diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index 605e9ef8d..4d4d98caf 100644 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -101,6 +101,11 @@ func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return heade func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { } +// WriteRawPacket implements stack.LinkEndpoint. +func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // MakeRandPkt generates a randomized packet. transportHeaderLength indicates // how many random bytes will be copied in the Transport Header. // extraHeaderReserveLength indicates how much extra space will be reserved for diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index e0847e58a..6c42ab29b 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -85,6 +85,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/internal/tcp", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/transport/tcpconntrack", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 72f66441f..ccb69393b 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -342,6 +342,10 @@ func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, p return n, nil } +func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index b854d868c..ddc1ddab6 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -734,10 +734,29 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.mu.RUnlock() // Deliver to interested packet endpoints without holding NIC lock. + var packetEPPkt *PacketBuffer deliverPacketEPs := func(ep PacketEndpoint) { - p := pkt.Clone() - p.PktType = tcpip.PacketHost - ep.HandlePacket(n.id, local, protocol, p) + if packetEPPkt == nil { + // Packet endpoints hold the full packet. + // + // We perform a deep copy because higher-level endpoints may point to + // the middle of a view that is held by a packet endpoint. Save/Restore + // does not support overlapping slices and will panic in this case. + // + // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports + // overlapping slices (e.g. by passing a shallow copy of pkt to the packet + // endpoint). + packetEPPkt = NewPacketBuffer(PacketBufferOptions{ + Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(), + }) + // If a link header was populated in the original packet buffer, then + // populate it in the packet buffer we provide to packet endpoints as + // packet endpoints inspect link headers. + packetEPPkt.LinkHeader().Consume(pkt.LinkHeader().View().Size()) + packetEPPkt.PktType = tcpip.PacketHost + } + + ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone()) } if protoEPs != nil { protoEPs.forEach(deliverPacketEPs) @@ -758,13 +777,30 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc eps := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() + var packetEPPkt *PacketBuffer eps.forEach(func(ep PacketEndpoint) { - p := pkt.Clone() - p.PktType = tcpip.PacketOutgoing - // Add the link layer header as outgoing packets are intercepted - // before the link layer header is created. - n.LinkEndpoint.AddHeader(local, remote, protocol, p) - ep.HandlePacket(n.id, local, protocol, p) + if packetEPPkt == nil { + // Packet endpoints hold the full packet. + // + // We perform a deep copy because higher-level endpoints may point to + // the middle of a view that is held by a packet endpoint. Save/Restore + // does not support overlapping slices and will panic in this case. + // + // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports + // overlapping slices (e.g. by passing a shallow copy of pkt to the packet + // endpoint). + packetEPPkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: pkt.AvailableHeaderBytes(), + Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), + }) + // Add the link layer header as outgoing packets are intercepted before + // the link layer header is created and packet endpoints are interested + // in the link header. + n.LinkEndpoint.AddHeader(local, remote, protocol, packetEPPkt) + packetEPPkt.PktType = tcpip.PacketOutgoing + } + + ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone()) }) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index dfe2c886f..57b3348b2 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -846,6 +846,14 @@ type LinkEndpoint interface { // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + + // WriteRawPacket writes a packet directly to the link. + // + // If the link-layer has its own header, the payload must already include the + // header. + // + // WriteRawPacket takes ownership of the packet. + WriteRawPacket(*PacketBuffer) tcpip.Error } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8e5c6edbf..cb741e540 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -72,7 +72,8 @@ type Stack struct { // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. - rawFactory RawFactory + rawFactory RawFactory + packetEndpointWriteSupported bool demux *transportDemuxer @@ -218,6 +219,10 @@ type Options struct { // this is non-nil. RawFactory RawFactory + // AllowPacketEndpointWrite determines if packet endpoints support write + // operations. + AllowPacketEndpointWrite bool + // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data // returned by the stack secure RNG. @@ -359,23 +364,24 @@ func New(opts Options) *Stack { opts.NUDConfigs.resetInvalidFields() s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*nic), - defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), - seed: seed, - nudConfigs: opts.NUDConfigs, - uniqueIDGenerator: opts.UniqueID, - nudDisp: opts.NUDDisp, - randomGenerator: randomGenerator, - secureRNG: opts.SecureRNG, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + nics: make(map[tcpip.NICID]*nic), + packetEndpointWriteSupported: opts.AllowPacketEndpointWrite, + defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: opts.IPTables, + icmpRateLimiter: NewICMPRateLimiter(), + seed: seed, + nudConfigs: opts.NUDConfigs, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + randomGenerator: randomGenerator, + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -1653,9 +1659,27 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, ReserveHeaderBytes: int(nic.MaxHeaderLength()), Data: payload, }) + pkt.NetworkProtocolNumber = netProto return nic.WritePacketToRemote(remote, netProto, pkt) } +// WriteRawPacket writes data directly to the specified NIC without adding any +// headers. +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + if !ok { + return &tcpip.ErrUnknownNICID{} + } + + pkt := NewPacketBuffer(PacketBufferOptions{ + Data: payload, + }) + pkt.NetworkProtocolNumber = proto + return nic.WriteRawPacket(pkt) +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1823,6 +1847,13 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol return nic.setNUDConfigs(proto, c) } +// Seed returns a 32 bit value that can be used as a seed value. +// +// NOTE: The seed is generated once during stack initialization only. +func (s *Stack) Seed() uint32 { + return s.seed +} + // Rand returns a reference to a pseudo random generator that can be used // to generate random numbers as required. func (s *Stack) Rand() *rand.Rand { @@ -1940,3 +1971,9 @@ func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProto return false } + +// PacketEndpointWriteSupported returns true iff packet endpoints support write +// operations. +func (s *Stack) PacketEndpointWriteSupported() bool { + return s.packetEndpointWriteSupported +} diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index 93ea83cdc..dc7289441 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/internal/tcp" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -402,7 +403,7 @@ type TCPSndBufState struct { type TCPEndpointStateInner struct { // TSOffset is a randomized offset added to the value of the TSVal // field in the timestamp option. - TSOffset uint32 + TSOffset tcp.TSOffset // SACKPermitted is set to true if the peer sends the TCPSACKPermitted // option in the SYN/SYN-ACK. diff --git a/pkg/tcpip/transport/BUILD b/pkg/tcpip/transport/BUILD new file mode 100644 index 000000000..af332ed91 --- /dev/null +++ b/pkg/tcpip/transport/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "transport", + srcs = [ + "datagram.go", + "transport.go", + ], + visibility = ["//visibility:public"], + deps = ["//pkg/tcpip"], +) diff --git a/pkg/tcpip/transport/datagram.go b/pkg/tcpip/transport/datagram.go new file mode 100644 index 000000000..dfce72c69 --- /dev/null +++ b/pkg/tcpip/transport/datagram.go @@ -0,0 +1,49 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// DatagramEndpointState is the state of a datagram-based endpoint. +type DatagramEndpointState tcpip.EndpointState + +// The states a datagram-based endpoint may be in. +const ( + _ DatagramEndpointState = iota + DatagramEndpointStateInitial + DatagramEndpointStateBound + DatagramEndpointStateConnected + DatagramEndpointStateClosed +) + +// String implements fmt.Stringer. +func (s DatagramEndpointState) String() string { + switch s { + case DatagramEndpointStateInitial: + return "INITIAL" + case DatagramEndpointStateBound: + return "BOUND" + case DatagramEndpointStateConnected: + return "CONNECTED" + case DatagramEndpointStateClosed: + return "CLOSED" + default: + panic(fmt.Sprintf("unhandled %[1]T variant = %[1]d", s)) + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index f9a15efb2..00497bf07 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -329,6 +329,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp route = r } + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { return 0, &tcpip.ErrBadBuffer{} diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD new file mode 100644 index 000000000..d10e3f13a --- /dev/null +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "network", + srcs = [ + "endpoint.go", + "endpoint_state.go", + ], + visibility = [ + "//pkg/tcpip/transport/udp:__pkg__", + ], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + ], +) + +go_test( + name = "network_test", + size = "small", + srcs = ["endpoint_test.go"], + deps = [ + ":network", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/udp", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go new file mode 100644 index 000000000..0dce60d89 --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -0,0 +1,722 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package network provides facilities to support tcpip.Endpoints that operate +// at the network layer or above. +package network + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" +) + +// Endpoint is a datagram-based endpoint. It only supports sending datagrams to +// a peer. +// +// +stateify savable +type Endpoint struct { + // The following fields must only be set once then never changed. + stack *stack.Stack `state:"manual"` + ops *tcpip.SocketOptions + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + + // state holds a transport.DatagramBasedEndpointState. + // + // state must be read from/written to atomically. + state uint32 + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + info stack.TransportEndpointInfo + // owner is the owner of transmitted packets. + owner tcpip.PacketOwner + writeShutdown bool + effectiveNetProto tcpip.NetworkProtocolNumber + connectedRoute *stack.Route `state:"manual"` + multicastMemberships map[multicastMembership]struct{} + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + ttl uint8 + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + multicastTTL uint8 + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + multicastAddr tcpip.Address + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + multicastNICID tcpip.NICID + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + sendTOS uint8 +} + +// +stateify savable +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + +// Init initializes the endpoint. +func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) { + if e.multicastMemberships != nil { + panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships)) + } + + switch netProto { + case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber: + default: + panic(fmt.Sprintf("invalid protocol number = %d", netProto)) + } + + *e = Endpoint{ + stack: s, + ops: ops, + netProto: netProto, + transProto: transProto, + + state: uint32(transport.DatagramEndpointStateInitial), + + info: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, + effectiveNetProto: netProto, + // Linux defaults to TTL=1. + multicastTTL: 1, + multicastMemberships: make(map[multicastMembership]struct{}), + } +} + +// NetProto returns the network protocol the endpoint was initialized with. +func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { + return e.netProto +} + +// setState sets the state of the endpoint. +func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { + atomic.StoreUint32(&e.state, uint32(state)) +} + +// State returns the state of the endpoint. +func (e *Endpoint) State() transport.DatagramEndpointState { + return transport.DatagramEndpointState(atomic.LoadUint32(&e.state)) +} + +// Close cleans the endpoint's resources and leaves the endpoint in a closed +// state. +func (e *Endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.State() == transport.DatagramEndpointStateClosed { + return + } + + for mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + + if e.connectedRoute != nil { + e.connectedRoute.Release() + e.connectedRoute = nil + } + + e.setEndpointState(transport.DatagramEndpointStateClosed) +} + +// SetOwner sets the owner of transmitted packets. +func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) { + e.mu.Lock() + defer e.mu.Unlock() + e.owner = owner +} + +func calculateTTL(route *stack.Route, ttl uint8, multicastTTL uint8) uint8 { + if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { + return multicastTTL + } + + if ttl == 0 { + return route.DefaultTTL() + } + + return ttl +} + +// WriteContext holds the context for a write. +type WriteContext struct { + transProto tcpip.TransportProtocolNumber + route *stack.Route + ttl uint8 + tos uint8 + owner tcpip.PacketOwner +} + +// Release releases held resources. +func (c *WriteContext) Release() { + c.route.Release() + *c = WriteContext{} +} + +// WritePacketInfo is the properties of a packet that may be written. +type WritePacketInfo struct { + NetProto tcpip.NetworkProtocolNumber + LocalAddress, RemoteAddress tcpip.Address + MaxHeaderLength uint16 + RequiresTXTransportChecksum bool +} + +// PacketInfo returns the properties of a packet that will be written. +func (c *WriteContext) PacketInfo() WritePacketInfo { + return WritePacketInfo{ + NetProto: c.route.NetProto(), + LocalAddress: c.route.LocalAddress(), + RemoteAddress: c.route.RemoteAddress(), + MaxHeaderLength: c.route.MaxHeaderLength(), + RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(), + } +} + +// WritePacket attempts to write the packet. +func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { + pkt.Owner = c.owner + + if headerIncluded { + return c.route.WriteHeaderIncludedPacket(pkt) + } + + return c.route.WritePacket(stack.NetworkHeaderParams{ + Protocol: c.transProto, + TTL: c.ttl, + TOS: c.tos, + }, pkt) +} + +// AcquireContextForWrite acquires a WriteContext. +func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. + if opts.More { + return WriteContext{}, &tcpip.ErrInvalidOptionValue{} + } + + if e.State() == transport.DatagramEndpointStateClosed { + return WriteContext{}, &tcpip.ErrInvalidEndpointState{} + } + + if e.writeShutdown { + return WriteContext{}, &tcpip.ErrClosedForSend{} + } + + route := e.connectedRoute + if opts.To == nil { + // If the user doesn't specify a destination, they should have + // connected to another address. + if e.State() != transport.DatagramEndpointStateConnected { + return WriteContext{}, &tcpip.ErrDestinationRequired{} + } + + route.Acquire() + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicID := opts.To.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } + if e.info.BindNICID != 0 { + if nicID != 0 && nicID != e.info.BindNICID { + return WriteContext{}, &tcpip.ErrNoRoute{} + } + + nicID = e.info.BindNICID + } + + dst, netProto, err := e.checkV4MappedLocked(*opts.To) + if err != nil { + return WriteContext{}, err + } + + route, _, err = e.connectRoute(nicID, dst, netProto) + if err != nil { + return WriteContext{}, err + } + } + + if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { + route.Release() + return WriteContext{}, &tcpip.ErrBroadcastDisabled{} + } + + return WriteContext{ + transProto: e.transProto, + route: route, + ttl: calculateTTL(route, e.ttl, e.multicastTTL), + tos: e.sendTOS, + owner: e.owner, + }, nil +} + +// Disconnect disconnects the endpoint from its peer. +func (e *Endpoint) Disconnect() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.State() != transport.DatagramEndpointStateConnected { + return + } + + // Exclude ephemerally bound endpoints. + if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" { + e.info.ID = stack.TransportEndpointID{ + LocalAddress: e.info.ID.LocalAddress, + } + e.setEndpointState(transport.DatagramEndpointStateBound) + } else { + e.info.ID = stack.TransportEndpointID{} + e.setEndpointState(transport.DatagramEndpointStateInitial) + } + + e.connectedRoute.Release() + e.connectedRoute = nil +} + +// connectRoute establishes a route to the specified interface or the +// configured multicast interface if no interface is specified and the +// specified address is a multicast address. +func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { + localAddr := e.info.ID.LocalAddress + if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" + } + + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { + if nicID == 0 { + nicID = e.multicastNICID + } + if localAddr == "" && nicID == 0 { + localAddr = e.multicastAddr + } + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) + if err != nil { + return nil, 0, err + } + return r, nicID, nil +} + +// Connect connects the endpoint to the address. +func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error { + return nil + }) +} + +// ConnectAndThen connects the endpoint to the address and then calls the +// provided function. +// +// If the function returns an error, the endpoint's state does not change. The +// function will be called with the network protocol used to connect to the peer +// and the source and destination addresses that will be used to send traffic to +// the peer. +func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error { + addr.Port = 0 + + e.mu.Lock() + defer e.mu.Unlock() + + nicID := addr.NIC + switch e.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + if e.info.BindNICID == 0 { + break + } + + if nicID != 0 && nicID != e.info.BindNICID { + return &tcpip.ErrInvalidEndpointState{} + } + + nicID = e.info.BindNICID + default: + return &tcpip.ErrInvalidEndpointState{} + } + + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + r, nicID, err := e.connectRoute(nicID, addr, netProto) + if err != nil { + return err + } + + id := stack.TransportEndpointID{ + LocalAddress: e.info.ID.LocalAddress, + RemoteAddress: r.RemoteAddress(), + } + if e.State() == transport.DatagramEndpointStateInitial { + id.LocalAddress = r.LocalAddress() + } + + if err := f(r.NetProto(), e.info.ID, id); err != nil { + return err + } + + if e.connectedRoute != nil { + // If the endpoint was previously connected then release any previous route. + e.connectedRoute.Release() + } + e.connectedRoute = r + e.info.ID = id + e.info.RegisterNICID = nicID + e.effectiveNetProto = netProto + e.setEndpointState(transport.DatagramEndpointStateConnected) + return nil +} + +// Shutdown shutsdown the endpoint. +func (e *Endpoint) Shutdown() tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + e.writeShutdown = true + return nil + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } +} + +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { + unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) + if err != nil { + return tcpip.FullAddress{}, 0, err + } + return unwrapped, netProto, nil +} + +func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) +} + +// Bind binds the endpoint to the address. +func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { + return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error { + return nil + }) +} + +// BindAndThen binds the endpoint to the address and then calls the provided +// function. +// +// If the function returns an error, the endpoint's state does not change. The +// function will be called with the bound network protocol and address. +func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error { + addr.Port = 0 + + e.mu.Lock() + defer e.mu.Unlock() + + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.State() != transport.DatagramEndpointStateInitial { + return &tcpip.ErrInvalidEndpointState{} + } + + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + nicID := addr.NIC + if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { + nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr) + if nicID == 0 { + return &tcpip.ErrBadLocalAddress{} + } + } + + if err := f(netProto, addr.Addr); err != nil { + return err + } + + e.info.ID = stack.TransportEndpointID{ + LocalAddress: addr.Addr, + } + e.info.BindNICID = nicID + e.info.RegisterNICID = nicID + e.info.BindAddr = addr.Addr + e.effectiveNetProto = netProto + e.setEndpointState(transport.DatagramEndpointStateBound) + return nil +} + +// GetLocalAddress returns the address that the endpoint is bound to. +func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { + e.mu.RLock() + defer e.mu.RUnlock() + + addr := e.info.BindAddr + if e.State() == transport.DatagramEndpointStateConnected { + addr = e.connectedRoute.LocalAddress() + } + + return tcpip.FullAddress{ + NIC: e.info.RegisterNICID, + Addr: addr, + } +} + +// GetRemoteAddress returns the address that the endpoint is connected to. +func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.State() != transport.DatagramEndpointStateConnected { + return tcpip.FullAddress{}, false + } + + return tcpip.FullAddress{ + Addr: e.connectedRoute.RemoteAddress(), + NIC: e.info.RegisterNICID, + }, true +} + +// SetSockOptInt sets the socket option. +func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { + switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return &tcpip.ErrNotSupported{} + } + + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() + + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + case tcpip.IPv4TOSOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + } + + return nil +} + +// GetSockOptInt returns the socket option. +func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { + switch opt { + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + + case tcpip.MulticastTTLOption: + e.mu.Lock() + v := int(e.multicastTTL) + e.mu.Unlock() + return v, nil + + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + default: + return -1, &tcpip.ErrUnknownProtocolOption{} + } +} + +// SetSockOpt sets the socket option. +func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + switch v := opt.(type) { + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + defer e.mu.Unlock() + + fa := tcpip.FullAddress{Addr: v.InterfaceAddr} + fa, netProto, err := e.checkV4MappedLocked(fa) + if err != nil { + return err + } + nic := v.NIC + addr := fa.Addr + + if nic == 0 && addr == "" { + e.multicastAddr = "" + e.multicastNICID = 0 + break + } + + if nic != 0 { + if !e.stack.CheckNIC(nic) { + return &tcpip.ErrBadLocalAddress{} + } + } else { + nic = e.stack.CheckLocalAddress(0, netProto, addr) + if nic == 0 { + return &tcpip.ErrBadLocalAddress{} + } + } + + if e.info.BindNICID != 0 && e.info.BindNICID != nic { + return &tcpip.ErrInvalidEndpointState{} + } + + e.multicastNICID = nic + e.multicastAddr = addr + + case *tcpip.AddMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return &tcpip.ErrInvalidOptionValue{} + } + + nicID := v.NIC + + if v.InterfaceAddr.Unspecified() { + if nicID == 0 { + if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return &tcpip.ErrUnknownDevice{} + } + + memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.multicastMemberships[memToInsert]; ok { + return &tcpip.ErrPortInUse{} + } + + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships[memToInsert] = struct{}{} + + case *tcpip.RemoveMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return &tcpip.ErrInvalidOptionValue{} + } + + nicID := v.NIC + if v.InterfaceAddr.Unspecified() { + if nicID == 0 { + if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return &tcpip.ErrUnknownDevice{} + } + + memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.multicastMemberships[memToRemove]; !ok { + return &tcpip.ErrBadLocalAddress{} + } + + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + delete(e.multicastMemberships, memToRemove) + + case *tcpip.SocketDetachFilterOption: + return nil + } + return nil +} + +// GetSockOpt returns the socket option. +func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + switch o := opt.(type) { + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + *o = tcpip.MulticastInterfaceOption{ + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, + } + e.mu.Unlock() + + default: + return &tcpip.ErrUnknownProtocolOption{} + } + return nil +} + +// Info returns a copy of the endpoint info. +func (e *Endpoint) Info() stack.TransportEndpointInfo { + e.mu.RLock() + defer e.mu.RUnlock() + return e.info +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go new file mode 100644 index 000000000..858007156 --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -0,0 +1,56 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" +) + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *Endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + + e.stack = s + + for m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(fmt.Sprintf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err)) + } + } + + switch state := e.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound: + if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { + if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 { + panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress)) + } + } + case transport.DatagramEndpointStateConnected: + var err tcpip.Error + multicastLoop := e.ops.GetMulticastLoop() + e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) + if err != nil { + panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) + } + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go new file mode 100644 index 000000000..2c43eb66a --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint_test.go @@ -0,0 +1,209 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +func TestEndpointStateTransitions(t *testing.T) { + const ( + nicID = 1 + ) + + var ( + ipv4NICAddr = testutil.MustParse4("1.2.3.4") + ipv6NICAddr = testutil.MustParse6("a::1") + ipv4RemoteAddr = testutil.MustParse4("6.7.8.9") + ipv6RemoteAddr = testutil.MustParse6("b::1") + ) + + data := buffer.View([]byte{1, 2, 4, 5}) + v4Checker := func(t *testing.T, b buffer.View) { + checker.IPv4(t, b, + checker.SrcAddr(ipv4NICAddr), + checker.DstAddr(ipv4RemoteAddr), + checker.IPPayload(data), + ) + } + + v6Checker := func(t *testing.T, b buffer.View) { + checker.IPv6(t, b, + checker.SrcAddr(ipv6NICAddr), + checker.DstAddr(ipv6RemoteAddr), + checker.IPPayload(data), + ) + } + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + expectedMaxHeaderLength uint16 + expectedNetProto tcpip.NetworkProtocolNumber + expectedLocalAddr tcpip.Address + bindAddr tcpip.Address + expectedBoundAddr tcpip.Address + remoteAddr tcpip.Address + expectedRemoteAddr tcpip.Address + checker func(*testing.T, buffer.View) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + expectedMaxHeaderLength: header.IPv4MaximumHeaderSize, + expectedNetProto: ipv4.ProtocolNumber, + expectedLocalAddr: ipv4NICAddr, + bindAddr: header.IPv4AllSystems, + expectedBoundAddr: header.IPv4AllSystems, + remoteAddr: ipv4RemoteAddr, + expectedRemoteAddr: ipv4RemoteAddr, + checker: v4Checker, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + expectedMaxHeaderLength: header.IPv6FixedHeaderSize, + expectedNetProto: ipv6.ProtocolNumber, + expectedLocalAddr: ipv6NICAddr, + bindAddr: header.IPv6AllNodesMulticastAddress, + expectedBoundAddr: header.IPv6AllNodesMulticastAddress, + remoteAddr: ipv6RemoteAddr, + expectedRemoteAddr: ipv6RemoteAddr, + checker: v6Checker, + }, + { + name: "IPv4-mapped-IPv6", + netProto: ipv6.ProtocolNumber, + expectedMaxHeaderLength: header.IPv4MaximumHeaderSize, + expectedNetProto: ipv4.ProtocolNumber, + expectedLocalAddr: ipv4NICAddr, + bindAddr: testutil.MustParse6("::ffff:e000:0001"), + expectedBoundAddr: header.IPv4AllSystems, + remoteAddr: testutil.MustParse6("::ffff:0607:0809"), + expectedRemoteAddr: ipv4RemoteAddr, + checker: v4Checker, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + {Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID}, + {Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID}, + }) + + var ops tcpip.SocketOptions + var ep network.Endpoint + ep.Init(s, test.netProto, udp.ProtocolNumber, &ops) + if state := ep.State(); state != transport.DatagramEndpointStateInitial { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial) + } + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + if state := ep.State(); state != transport.DatagramEndpointStateBound { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound) + } + if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedBoundAddr}); diff != "" { + t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff) + } + if addr, connected := ep.GetRemoteAddress(); connected { + t.Errorf("got ep.GetRemoteAddress() = (true, %#v), want = (false, _)", addr) + } + + connectAddr := tcpip.FullAddress{Addr: test.remoteAddr} + if err := ep.Connect(connectAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", connectAddr, err) + } + if state := ep.State(); state != transport.DatagramEndpointStateConnected { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected) + } + if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedLocalAddr}); diff != "" { + t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff) + } + if addr, connected := ep.GetRemoteAddress(); !connected { + t.Errorf("got ep.GetRemoteAddress() = (false, _), want = (true, %#v)", connectAddr) + } else if diff := cmp.Diff(addr, tcpip.FullAddress{Addr: test.expectedRemoteAddr}); diff != "" { + t.Errorf("remote address mismatch (-want +got):\n%s", diff) + } + + ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("ep.AcquireContexForWrite({}): %s", err) + } + defer ctx.Release() + info := ctx.PacketInfo() + if diff := cmp.Diff(network.WritePacketInfo{ + NetProto: test.expectedNetProto, + LocalAddress: test.expectedLocalAddr, + RemoteAddress: test.expectedRemoteAddr, + MaxHeaderLength: test.expectedMaxHeaderLength, + RequiresTXTransportChecksum: true, + }, info); diff != "" { + t.Errorf("write packet info mismatch (-want +got):\n%s", diff) + } + if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(info.MaxHeaderLength), + Data: data.ToVectorisedView(), + }), false /* headerIncluded */); err != nil { + t.Fatalf("ctx.WritePacket(_, false): %s", err) + } + if pkt, ok := e.Read(); !ok { + t.Fatalf("expected packet to be read from link endpoint") + } else { + test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader())) + } + + ep.Close() + if state := ep.State(); state != transport.DatagramEndpointStateClosed { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed) + } + }) + } +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 8e7bb6c6e..89b4720aa 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -207,8 +207,52 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul return res, nil } -func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { - return 0, &tcpip.ErrInvalidOptionValue{} +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + if !ep.stack.PacketEndpointWriteSupported() { + return 0, &tcpip.ErrNotSupported{} + } + + ep.mu.Lock() + closed := ep.closed + nicID := ep.boundNIC + ep.mu.Unlock() + if closed { + return 0, &tcpip.ErrClosedForSend{} + } + + var remote tcpip.LinkAddress + proto := ep.netProto + if to := opts.To; to != nil { + remote = tcpip.LinkAddress(to.Addr) + + if n := to.NIC; n != 0 { + nicID = n + } + + if p := to.Port; p != 0 { + proto = tcpip.NetworkProtocolNumber(p) + } + } + + if nicID == 0 { + return 0, &tcpip.ErrInvalidOptionValue{} + } + + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. + payloadBytes := make(buffer.View, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, &tcpip.ErrBadBuffer{} + } + + if err := func() tcpip.Error { + if ep.cooked { + return ep.stack.WritePacketToRemote(nicID, remote, proto, payloadBytes.ToVectorisedView()) + } + return ep.stack.WriteRawPacket(nicID, proto, payloadBytes.ToVectorisedView()) + }(); err != nil { + return 0, err + } + return int64(len(payloadBytes)), nil } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index e729921db..5c688d286 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -34,17 +34,11 @@ func (p *packet) loadReceivedAt(nsec int64) { // saveData saves packet.data field. func (p *packet) saveData() buffer.VectorisedView { - // We cannot save p.data directly as p.data.views may alias to p.views, - // which is not allowed by state framework (in-struct pointer). return p.data.Clone(nil) } // loadData loads packet.data field. func (p *packet) loadData(data buffer.VectorisedView) { - // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization - // here because data.views is not guaranteed to be loaded by now. Plus, - // data.views will be allocated anyway so there really is little point - // of utilizing p.views for data.views. p.data = data } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 55854ba59..3bf6c0a8f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -281,6 +281,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return nil, nil, nil, &tcpip.ErrInvalidEndpointState{} } + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. payloadBytes := make([]byte, p.Len()) if _, err := io.ReadFull(p, payloadBytes); err != nil { return nil, nil, nil, &tcpip.ErrBadBuffer{} @@ -600,6 +601,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // We copy headers' underlying bytes because pkt.*Header may point to // the middle of a slice, and another struct may point to the "outer" // slice. Save/restore doesn't support overlapping slices and will fail. + // + // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports + // overlapping slices. var combinedVV buffer.VectorisedView if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index c3922bbe5..5148fe157 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -68,6 +68,7 @@ go_library( "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", + "//pkg/tcpip/internal/tcp", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index f8269efa6..03c9fafa1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -606,14 +606,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err MSS: calculateAdvertisedMSS(e.userMSS, route), } if opts.TS { - // Create a barely-sufficient endpoint to calculate the TSVal. - pseudoEndpoint := endpoint{ - TCPEndpointStateInner: stack.TCPEndpointStateInner{ - TSOffset: e.protocol.tsOffset(s.dstAddr, s.srcAddr), - }, - stack: e.stack, - } - synOpts.TSVal = pseudoEndpoint.tsValNow() + offset := e.protocol.tsOffset(s.dstAddr, s.srcAddr) + now := e.stack.Clock().NowMonotonic() + synOpts.TSVal = offset.TSVal(now) } cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) fields := tcpFields{ diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0623ee8ed..d2b8f298f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2913,7 +2913,7 @@ func (e *endpoint) maybeEnableTimestamp(synOpts header.TCPSynOptions) { } func (e *endpoint) tsVal(now tcpip.MonotonicTime) uint32 { - return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + e.TSOffset + return e.TSOffset.TSVal(now) } func (e *endpoint) tsValNow() uint32 { @@ -2921,7 +2921,7 @@ func (e *endpoint) tsValNow() uint32 { } func (e *endpoint) elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration { - return time.Duration(e.tsVal(now)-tsEcr) * time.Millisecond + return e.TSOffset.Elapsed(now, tsEcr) } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b0ffd2429..e4410ad93 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" + "gvisor.dev/gvisor/pkg/tcpip/internal/tcp" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -158,7 +159,7 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, return stack.UnknownDestinationPacketHandled } -func (p *protocol) tsOffset(src, dst tcpip.Address) uint32 { +func (p *protocol) tsOffset(src, dst tcpip.Address) tcp.TSOffset { // Initialize a random tsOffset that will be added to the recentTS // everytime the timestamp is sent when the Timestamp option is enabled. // @@ -173,7 +174,7 @@ func (p *protocol) tsOffset(src, dst tcpip.Address) uint32 { // It never returns an error. _, _ = h.Write([]byte(src)) _, _ = h.Write([]byte(dst)) - return h.Sum32() + return tcp.NewTSOffset(h.Sum32()) } // replyWithReset replies to the given segment with a reset segment. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 9ce8fcae9..90e493978 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -477,7 +477,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. - if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 { + if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && (r.PendingBufUsed+int(segLen)) < int(rcvBufSize)>>2 { r.ep.rcvQueueInfo.rcvQueueMu.Lock() r.PendingBufUsed += s.segMemSize() r.ep.rcvQueueInfo.rcvQueueMu.Unlock() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 90b74a2a7..bc8708a5b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2128,6 +2128,211 @@ func TestFullWindowReceive(t *testing.T) { ) } +func TestSmallReceiveBufferReadiness(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + }) + + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } + + const nicID = 1 + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err) + } + + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x7f\x00\x00\x01"), + PrefixLen: 8, + } + if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err) + } + + { + subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") + if err != nil { + t.Fatalf("tcpip.NewSubnet failed: %s", err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: subnet, + NIC: nicID, + }, + }) + } + + listenerEntry, listenerCh := waiter.NewChannelEntry(nil) + var listenerWQ waiter.Queue + listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer listener.Close() + listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents) + defer listenerWQ.EventUnregister(&listenerEntry) + + if err := listener.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := listener.Listen(1); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + localAddress, err := listener.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress failed: %s", err) + } + + for i := 8; i > 0; i /= 2 { + size := int64(i << 10) + t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) { + var clientWQ waiter.Queue + client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer client.Close() + switch err := client.Connect(localAddress).(type) { + case nil: + t.Fatal("Connect returned nil error") + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect failed: %s", err) + } + + <-listenerCh + server, serverWQ, err := listener.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + defer server.Close() + + client.SocketOptions().SetReceiveBufferSize(size, true) + // Send buffer size doesn't seem to affect this test. + // server.SocketOptions().SetSendBufferSize(size, true) + + clientEntry, clientCh := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents) + defer clientWQ.EventUnregister(&clientEntry) + + serverEntry, serverCh := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverEntry, waiter.WritableEvents) + defer serverWQ.EventUnregister(&serverEntry) + + var total int64 + for { + var b [64 << 10]byte + var r bytes.Reader + r.Reset(b[:]) + switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + case nil: + t.Logf("wrote %d bytes", n) + total += n + continue + case *tcpip.ErrWouldBlock: + select { + case <-serverCh: + continue + case <-time.After(100 * time.Millisecond): + // Well and truly full. + t.Logf("send and receive queues are full") + } + default: + t.Fatalf("Write failed: %s", err) + } + break + } + t.Logf("wrote %d bytes in total", total) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(2) + go func() { + defer wg.Done() + + var b [64 << 10]byte + var r bytes.Reader + r.Reset(b[:]) + if err := func() error { + var total int64 + defer t.Logf("wrote %d bytes in total", total) + for r.Len() != 0 { + switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + case nil: + t.Logf("wrote %d bytes", n) + total += n + case *tcpip.ErrWouldBlock: + for { + t.Logf("waiting on server") + select { + case <-serverCh: + case <-time.After(time.Second): + if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 { + t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness) + } + continue + } + break + } + default: + return fmt.Errorf("server.Write failed: %s", err) + } + } + if err := server.Shutdown(tcpip.ShutdownWrite); err != nil { + return fmt.Errorf("server.Shutdown failed: %s", err) + } + t.Logf("server end shutdown done") + return nil + }(); err != nil { + t.Error(err) + } + }() + + go func() { + defer wg.Done() + + if err := func() error { + total := 0 + defer t.Logf("read %d bytes in total", total) + for { + switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { + case nil: + t.Logf("read %d bytes", res.Count) + total += res.Count + t.Logf("read total %d bytes till now", total) + case *tcpip.ErrClosedForReceive: + return nil + case *tcpip.ErrWouldBlock: + for { + t.Logf("waiting on client") + select { + case <-clientCh: + case <-time.After(time.Second): + if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 { + return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness) + } + continue + } + break + } + default: + return fmt.Errorf("client.Write failed: %s", err) + } + } + }(); err != nil { + t.Error(err) + } + }() + }) + } +} + // Test the stack receive window advertisement on receiving segments smaller than // segment overhead. It tests for the right edge of the window to not grow when // the endpoint is not being read from. diff --git a/pkg/tcpip/transport/transport.go b/pkg/tcpip/transport/transport.go new file mode 100644 index 000000000..4c2ae87f4 --- /dev/null +++ b/pkg/tcpip/transport/transport.go @@ -0,0 +1,16 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package transport supports transport protocols. +package transport diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index cdc344ab7..5cc7a2886 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -35,6 +35,8 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 108580508..4b6bdc3be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,8 +15,8 @@ package udp import ( + "fmt" "io" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sync" @@ -25,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -40,36 +42,6 @@ type udpPacket struct { tos uint8 } -// EndpointState represents the state of a UDP endpoint. -type EndpointState tcpip.EndpointState - -// Endpoint states. Note that are represented in a netstack-specific manner and -// may not be meaningful externally. Specifically, they need to be translated to -// Linux's representation for these states if presented to userspace. -const ( - _ EndpointState = iota - StateInitial - StateBound - StateConnected - StateClosed -) - -// String implements fmt.Stringer. -func (s EndpointState) String() string { - switch s { - case StateInitial: - return "INITIAL" - case StateBound: - return "BOUND" - case StateConnected: - return "CONNECTING" - case StateClosed: - return "CLOSED" - default: - return "UNKNOWN" - } -} - // endpoint represents a UDP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -79,7 +51,6 @@ func (s EndpointState) String() string { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and do not @@ -87,6 +58,10 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats tcpip.TransportEndpointStats `state:"nosave"` + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -96,37 +71,19 @@ type endpoint struct { rcvBufSize int rcvClosed bool - // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - // state must be read/set using the EndpointState()/setEndpointState() - // methods. - state uint32 - route *stack.Route `state:"manual"` - dstPort uint16 - ttl uint8 - multicastTTL uint8 - multicastAddr tcpip.Address - multicastNICID tcpip.NICID - portFlags ports.Flags - lastErrorMu sync.Mutex `state:"nosave"` lastError tcpip.Error + // The following fields are protected by the mu mutex. + mu sync.RWMutex `state:"nosave"` + portFlags ports.Flags + // Values used to reserve a port or register a transport endpoint. // (which ever happens first). boundBindToDevice tcpip.NICID boundPortFlags ports.Flags - // sendTOS represents IPv4 TOS or IPv6 TrafficClass, - // applied while sending packets. Defaults to 0 as on Linux. - sendTOS uint8 - - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - - // multicastMemberships that need to be remvoed when the endpoint is - // closed. Protected by the mu mutex. - multicastMemberships map[multicastMembership]struct{} + readShutdown bool // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -136,55 +93,25 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool -} -// +stateify savable -type multicastMembership struct { - nicID tcpip.NICID - multicastAddr tcpip.Address + localPort uint16 + remotePort uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: header.UDPProtocolNumber, - }, + stack: s, waiterQueue: waiterQueue, - // RFC 1075 section 5.4 recommends a TTL of 1 for membership - // requests. - // - // RFC 5135 4.2.1 appears to assume that IGMP messages have a - // TTL of 1. - // - // RFC 5135 Appendix A defines TTL=1: A multicast source that - // wants its traffic to not traverse a router (e.g., leave a - // home network) may find it useful to send traffic with IP - // TTL=1. - // - // Linux defaults to TTL=1. - multicastTTL: 1, - multicastMemberships: make(map[multicastMembership]struct{}), - state: uint32(StateInitial), - uniqueID: s.UniqueID(), + uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetSendBufferSize(32*1024, false /* notify */) e.ops.SetReceiveBufferSize(32*1024, false /* notify */) + e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -200,20 +127,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue return e } -// setEndpointState updates the state of the endpoint to state atomically. This -// method is unexported as the only place we should update the state is in this -// package but we allow the state to be read freely without holding e.mu. -// -// Precondition: e.mu must be held to call this method. -func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32(&e.state, uint32(state)) -} - -// EndpointState() returns the current state of the endpoint. -func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32(&e.state)) -} - // UniqueID implements stack.TransportEndpoint. func (e *endpoint) UniqueID() uint64 { return e.uniqueID @@ -244,16 +157,22 @@ func (e *endpoint) Abort() { // associated with it. func (e *endpoint) Close() { e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.EndpointState() { - case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + e.mu.Unlock() + return + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + id := e.net.Info().ID + id.LocalPort = e.localPort + id.RemotePort = e.remotePort + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice) portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: id.LocalAddress, + Port: id.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: tcpip.FullAddress{}, @@ -261,13 +180,10 @@ func (e *endpoint) Close() { e.stack.ReleasePort(portRes) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } - for mem := range e.multicastMemberships { - e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) - } - e.multicastMemberships = nil - // Close the receive list and drain it. e.rcvMu.Lock() e.rcvClosed = true @@ -278,14 +194,9 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - if e.route != nil { - e.route.Release() - e.route = nil - } - - // Update the state. - e.setEndpointState(StateClosed) - + e.net.Shutdown() + e.net.Close() + e.readShutdown = true e.mu.Unlock() e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) @@ -359,19 +270,19 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult return res, nil } -// prepareForWrite prepares the endpoint for sending data. In particular, it -// binds it if it's still in the initial state. To do so, it must first +// prepareForWriteInner prepares the endpoint for sending data. In particular, +// it binds it if it's still in the initial state. To do so, it must first // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. // +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.EndpointState() { - case StateInitial: - case StateConnected: +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - case StateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -386,7 +297,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.EndpointState() != StateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -398,33 +309,6 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip return true, nil } -// connectRoute establishes a route to the specified interface or the -// configured multicast interface if no interface is specified and the -// specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { - localAddr := e.ID.LocalAddress - if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { - // A packet can only originate from a unicast address (i.e., an interface). - localAddr = "" - } - - if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicID == 0 { - nicID = e.multicastNICID - } - if localAddr == "" && nicID == 0 { - localAddr = e.multicastAddr - } - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) - if err != nil { - return nil, 0, err - } - return r, nicID, nil -} - // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { @@ -448,18 +332,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { +func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return udpPacketInfo{}, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(opts.To) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { return udpPacketInfo{}, err } @@ -469,49 +348,28 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions } } - route := e.route - dstPort := e.dstPort + dst, connected := e.net.GetRemoteAddress() + dst.Port = e.remotePort if opts.To != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := opts.To.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return udpPacketInfo{}, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - if opts.To.Port == 0 { // Port 0 is an invalid port to send to. return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{} } - dst, netProto, err := e.checkV4MappedLocked(*opts.To) - if err != nil { - return udpPacketInfo{}, err - } - - r, _, err := e.connectRoute(nicID, dst, netProto) - if err != nil { - return udpPacketInfo{}, err - } - defer r.Release() - - route = r - dstPort = dst.Port + dst = *opts.To + } else if !connected { + return udpPacketInfo{}, &tcpip.ErrDestinationRequired{} } - if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{} + ctx, err := e.net.AcquireContextForWrite(opts) + if err != nil { + return udpPacketInfo{}, err } + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { + ctx.Release() return udpPacketInfo{}, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { @@ -520,50 +378,25 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions if so.GetRecvError() { so.QueueLocalErr( &tcpip.ErrMessageTooLong{}, - route.NetProto(), + e.net.NetProto(), header.UDPMaximumPacketSize, - tcpip.FullAddress{ - NIC: route.NICID(), - Addr: route.RemoteAddress(), - Port: dstPort, - }, + dst, v, ) } + ctx.Release() return udpPacketInfo{}, &tcpip.ErrMessageTooLong{} } - ttl := e.ttl - useDefaultTTL := ttl == 0 - if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { - ttl = e.multicastTTL - // Multicast allows a 0 TTL. - useDefaultTTL = false - } - return udpPacketInfo{ - route: route, - data: buffer.View(v), - localPort: e.ID.LocalPort, - remotePort: dstPort, - ttl: ttl, - useDefaultTTL: useDefaultTTL, - tos: e.sendTOS, - owner: e.owner, - noChecksum: e.SocketOptions().GetNoChecksum(), + ctx: ctx, + data: v, + localPort: e.localPort, + remotePort: dst.Port, }, nil } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if err := e.LastError(); err != nil { - return 0, err - } - - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - // Do not hold lock when sending as loopback is synchronous and if the UDP // datagram ends up generating an ICMP response then it can result in a // deadlock where the ICMP response handling ends up acquiring this endpoint's @@ -574,15 +407,53 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read // locking is prohibited. - u, err := e.buildUDPPacketInfo(p, opts) - if err != nil { + + if err := e.LastError(); err != nil { return 0, err } - n, err := u.send() + + udpInfo, err := e.prepareForWrite(p, opts) if err != nil { return 0, err } - return int64(n), nil + defer udpInfo.ctx.Release() + + pktInfo := udpInfo.ctx.PacketInfo() + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength), + Data: udpInfo.data.ToVectorisedView(), + }) + + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = ProtocolNumber + + length := uint16(pkt.Size()) + udp.Encode(&header.UDPFields{ + SrcPort: udpInfo.localPort, + DstPort: udpInfo.remotePort, + Length: length, + }) + + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if pktInfo.RequiresTXTransportChecksum && + (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) { + udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine( + header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length), + pkt.Data().AsRange().Checksum(), + ))) + } + if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + e.stack.Stats().UDP.PacketSendErrors.Increment() + return 0, err + } + + // Track count of packets sent. + e.stack.Stats().UDP.PacketsSent.Increment() + return int64(len(udpInfo.data)), nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler. @@ -601,36 +472,7 @@ func (e *endpoint) OnReusePortSet(v bool) { // SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.MTUDiscoverOption: - // Return not supported if the value is not disabling path - // MTU discovery. - if v != tcpip.PMTUDiscoveryDont { - return &tcpip.ErrNotSupported{} - } - - case tcpip.MulticastTTLOption: - e.mu.Lock() - e.multicastTTL = uint8(v) - e.mu.Unlock() - - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - case tcpip.IPv4TOSOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - } - - return nil + return e.net.SetSockOptInt(opt, v) } var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) @@ -642,145 +484,12 @@ func (e *endpoint) HasNIC(id int32) bool { // SetSockOpt implements tcpip.Endpoint. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { - switch v := opt.(type) { - case *tcpip.MulticastInterfaceOption: - e.mu.Lock() - defer e.mu.Unlock() - - fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedLocked(fa) - if err != nil { - return err - } - nic := v.NIC - addr := fa.Addr - - if nic == 0 && addr == "" { - e.multicastAddr = "" - e.multicastNICID = 0 - break - } - - if nic != 0 { - if !e.stack.CheckNIC(nic) { - return &tcpip.ErrBadLocalAddress{} - } - } else { - nic = e.stack.CheckLocalAddress(0, netProto, addr) - if nic == 0 { - return &tcpip.ErrBadLocalAddress{} - } - } - - if e.BindNICID != 0 && e.BindNICID != nic { - return &tcpip.ErrInvalidEndpointState{} - } - - e.multicastNICID = nic - e.multicastAddr = addr - - case *tcpip.AddMembershipOption: - if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return &tcpip.ErrInvalidOptionValue{} - } - - nicID := v.NIC - - if v.InterfaceAddr.Unspecified() { - if nicID == 0 { - if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil { - nicID = r.NICID() - r.Release() - } - } - } else { - nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) - } - if nicID == 0 { - return &tcpip.ErrUnknownDevice{} - } - - memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - - e.mu.Lock() - defer e.mu.Unlock() - - if _, ok := e.multicastMemberships[memToInsert]; ok { - return &tcpip.ErrPortInUse{} - } - - if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { - return err - } - - e.multicastMemberships[memToInsert] = struct{}{} - - case *tcpip.RemoveMembershipOption: - if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return &tcpip.ErrInvalidOptionValue{} - } - - nicID := v.NIC - if v.InterfaceAddr.Unspecified() { - if nicID == 0 { - if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil { - nicID = r.NICID() - r.Release() - } - } - } else { - nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) - } - if nicID == 0 { - return &tcpip.ErrUnknownDevice{} - } - - memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - - e.mu.Lock() - defer e.mu.Unlock() - - if _, ok := e.multicastMemberships[memToRemove]; !ok { - return &tcpip.ErrBadLocalAddress{} - } - - if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { - return err - } - - delete(e.multicastMemberships, memToRemove) - - case *tcpip.SocketDetachFilterOption: - return nil - } - return nil + return e.net.SetSockOpt(opt) } // GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { - case tcpip.IPv4TOSOption: - e.mu.RLock() - v := int(e.sendTOS) - e.mu.RUnlock() - return v, nil - - case tcpip.IPv6TrafficClassOption: - e.mu.RLock() - v := int(e.sendTOS) - e.mu.RUnlock() - return v, nil - - case tcpip.MTUDiscoverOption: - // The only supported setting is path MTU discovery disabled. - return tcpip.PMTUDiscoveryDont, nil - - case tcpip.MulticastTTLOption: - e.mu.Lock() - v := int(e.multicastTTL) - e.mu.Unlock() - return v, nil - case tcpip.ReceiveQueueSizeOption: v := 0 e.rcvMu.Lock() @@ -791,108 +500,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.mu.Lock() - v := int(e.ttl) - e.mu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } // GetSockOpt implements tcpip.Endpoint. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { - switch o := opt.(type) { - case *tcpip.MulticastInterfaceOption: - e.mu.Lock() - *o = tcpip.MulticastInterfaceOption{ - NIC: e.multicastNICID, - InterfaceAddr: e.multicastAddr, - } - e.mu.Unlock() - - default: - return &tcpip.ErrUnknownProtocolOption{} - } - return nil + return e.net.GetSockOpt(opt) } -// udpPacketInfo contains all information required to send a UDP packet. -// -// This should be used as a value-only type, which exists in order to simplify -// return value syntax. It should not be exported or extended. +// udpPacketInfo holds information needed to send a UDP packet. type udpPacketInfo struct { - route *stack.Route - data buffer.View - localPort uint16 - remotePort uint16 - ttl uint8 - useDefaultTTL bool - tos uint8 - owner tcpip.PacketOwner - noChecksum bool -} - -// send sends the given packet. -func (u *udpPacketInfo) send() (int, tcpip.Error) { - vv := u.data.ToVectorisedView() - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()), - Data: vv, - }) - pkt.Owner = u.owner - - // Initialize the UDP header. - udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - pkt.TransportProtocolNumber = ProtocolNumber - - length := uint16(pkt.Size()) - udp.Encode(&header.UDPFields{ - SrcPort: u.localPort, - DstPort: u.remotePort, - Length: length, - }) - - // Set the checksum field unless TX checksum offload is enabled. - // On IPv4, UDP checksum is optional, and a zero value indicates the - // transmitter skipped the checksum generation (RFC768). - // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if u.route.RequiresTXTransportChecksum() && - (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) { - xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length) - for _, v := range vv.Views() { - xsum = header.Checksum(v, xsum) - } - udp.SetChecksum(^udp.CalculateChecksum(xsum)) - } - - if u.useDefaultTTL { - u.ttl = u.route.DefaultTTL() - } - if err := u.route.WritePacket(stack.NetworkHeaderParams{ - Protocol: ProtocolNumber, - TTL: u.ttl, - TOS: u.tos, - }, pkt); err != nil { - u.route.Stats().UDP.PacketSendErrors.Increment() - return 0, err - } - - // Track count of packets sent. - u.route.Stats().UDP.PacketsSent.Increment() - return len(u.data), nil -} - -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil + ctx network.WriteContext + data buffer.View + localPort uint16 + remotePort uint16 } // Disconnect implements tcpip.Endpoint. @@ -900,7 +523,7 @@ func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.EndpointState() != StateConnected { + if e.net.State() != transport.DatagramEndpointStateConnected { return nil } var ( @@ -913,26 +536,28 @@ func (e *endpoint) Disconnect() tcpip.Error { boundPortFlags := e.boundPortFlags // Exclude ephemerally bound endpoints. - if e.BindNICID != 0 || e.ID.LocalAddress == "" { + info := e.net.Info() + info.ID.LocalPort = e.localPort + info.ID.RemotePort = e.remotePort + if info.BindNICID != 0 || info.ID.LocalAddress == "" { var err tcpip.Error id = stack.TransportEndpointID{ - LocalPort: e.ID.LocalPort, - LocalAddress: e.ID.LocalAddress, + LocalPort: info.ID.LocalPort, + LocalAddress: info.ID.LocalAddress, } id, btd, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { return err } - e.setEndpointState(StateBound) boundPortFlags = e.boundPortFlags } else { - if e.ID.LocalPort != 0 { + if info.ID.LocalPort != 0 { // Release the ephemeral port. portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: info.ID.LocalAddress, + Port: info.ID.LocalPort, Flags: boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: tcpip.FullAddress{}, @@ -940,15 +565,14 @@ func (e *endpoint) Disconnect() tcpip.Error { e.stack.ReleasePort(portRes) e.boundPortFlags = ports.Flags{} } - e.setEndpointState(StateInitial) } - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) - e.ID = id + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice) e.boundBindToDevice = btd - e.route.Release() - e.route = nil - e.dstPort = 0 + e.localPort = id.LocalPort + e.remotePort = id.RemotePort + + e.net.Disconnect() return nil } @@ -958,88 +582,48 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - var localPort uint16 - switch e.EndpointState() { - case StateInitial: - case StateBound, StateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } - - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.localPort + nextID.RemotePort = addr.Port + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv4ProtocolNumber, + header.IPv6ProtocolNumber, + } } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } - - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - r, nicID, err := e.connectRoute(nicID, addr, netProto) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: e.ID.LocalAddress, - LocalPort: localPort, - RemotePort: addr.Port, - RemoteAddress: r.RemoteAddress(), - } + oldPortFlags := e.boundPortFlags - if e.EndpointState() == StateInitial { - id.LocalAddress = r.LocalAddress() - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv4ProtocolNumber, - header.IPv6ProtocolNumber, + nextID, btd, err := e.registerWithStack(netProtos, nextID) + if err != nil { + return err } - } - oldPortFlags := e.boundPortFlags + // Remove the old registration. + if e.localPort != 0 { + previousID.LocalPort = e.localPort + previousID.RemotePort = e.remotePort + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice) + } - id, btd, err := e.registerWithStack(netProtos, id) + e.localPort = nextID.LocalPort + e.remotePort = nextID.RemotePort + e.boundBindToDevice = btd + e.effectiveNetProtos = netProtos + return nil + }) if err != nil { - r.Release() return err } - // Remove the old registration. - if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) - } - - e.ID = id - e.boundBindToDevice = btd - if e.route != nil { - // If the endpoint was already connected then make sure we release the - // previous route. - e.route.Release() - } - e.route = r - e.dstPort = addr.Port - e.RegisterNICID = nicID - e.effectiveNetProtos = netProtos - - e.setEndpointState(StateConnected) - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() - return nil } @@ -1054,15 +638,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - // A socket in the bound state can still receive multicast messages, - // so we need to notify waiters on shutdown. - if state := e.EndpointState(); state != StateBound && state != StateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } - e.shutdownFlags |= flags + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } + } if flags&tcpip.ShutdownRead != 0 { + e.readShutdown = true + e.rcvMu.Lock() wasClosed := e.rcvClosed e.rcvClosed = true @@ -1088,7 +680,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - if e.ID.LocalPort == 0 { + if e.localPort == 0 { portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, @@ -1126,56 +718,43 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.EndpointState() != StateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{boundNetProto} + if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } } - } - nicID := addr.NIC - if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { - // A local unicast address was specified, verify that it's valid. - nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicID == 0 { - return &tcpip.ErrBadLocalAddress{} + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: boundAddr, + } + id, btd, err := e.registerWithStack(netProtos, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, btd, err := e.registerWithStack(netProtos, id) + e.localPort = id.LocalPort + e.boundBindToDevice = btd + e.effectiveNetProtos = netProtos + return nil + }) if err != nil { return err } - e.ID = id - e.boundBindToDevice = btd - e.RegisterNICID = nicID - e.effectiveNetProtos = netProtos - - // Mark endpoint as bound. - e.setEndpointState(StateBound) - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() - return nil } @@ -1190,9 +769,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { return err } - // Save the effective NICID generated by bindLocked. - e.BindNICID = e.RegisterNICID - return nil } @@ -1201,16 +777,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - addr := e.ID.LocalAddress - if e.EndpointState() == StateConnected { - addr = e.route.LocalAddress() - } - - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: addr, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.localPort + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -1218,15 +787,13 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.EndpointState() != StateConnected || e.dstPort == 0 { + addr, connected := e.net.GetRemoteAddress() + if !connected || e.remotePort == 0 { return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + addr.Port = e.remotePort + return addr, nil } // Readiness returns the current readiness of the endpoint. For example, if @@ -1376,19 +943,20 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p payload = udp.Payload() } + id := e.net.Info().ID e.SocketOptions().QueueErr(&tcpip.SockError{ Err: err, Cause: transErr, Payload: payload, Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: id.RemoteAddress, + Port: e.remotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: id.LocalAddress, + Port: e.localPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -1403,7 +971,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // TODO(gvisor.dev/issues/5270): Handle all transport errors. switch transErr.Kind() { case stack.DestinationPortUnreachableTransportError: - if e.EndpointState() == StateConnected { + if e.net.State() == transport.DatagramEndpointStateConnected { e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt) } } @@ -1411,16 +979,17 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // State implements tcpip.Endpoint. func (e *endpoint) State() uint32 { - return uint32(e.EndpointState()) + return uint32(e.net.State()) } // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() - return &ret + defer e.mu.RUnlock() + info := e.net.Info() + info.ID.LocalPort = e.localPort + info.ID.RemotePort = e.remotePort + return &info } // Stats returns a pointer to the endpoint stats. @@ -1431,13 +1000,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements tcpip.Endpoint. func (*endpoint) Wait() {} -func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) -} - // SetOwner implements tcpip.Endpoint. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // SocketOptions implements tcpip.Endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 1f638c3f6..2ff8b0482 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,12 +15,13 @@ package udp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -35,17 +36,11 @@ func (p *udpPacket) loadReceivedAt(nsec int64) { // saveData saves udpPacket.data field. func (p *udpPacket) saveData() buffer.VectorisedView { - // We cannot save p.data directly as p.data.views may alias to p.views, - // which is not allowed by state framework (in-struct pointer). return p.data.Clone(nil) } // loadData loads udpPacket.data field. func (p *udpPacket) loadData(data buffer.VectorisedView) { - // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization - // here because data.views is not guaranteed to be loaded by now. Plus, - // data.views will be allocated anyway so there really is little point - // of utilizing p.views for data.views. p.data = data } @@ -66,50 +61,28 @@ func (e *endpoint) Resume(s *stack.Stack) { e.mu.Lock() defer e.mu.Unlock() + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - for m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } - - state := e.EndpointState() - if state != StateBound && state != StateConnected { - return - } - - netProto := e.effectiveNetProtos[0] - // Connect() and bindLocked() both assert - // - // netProto == header.IPv6ProtocolNumber - // - // before creating a multi-entry effectiveNetProtos. - if len(e.effectiveNetProtos) > 1 { - netProto = header.IPv6ProtocolNumber - } - - var err tcpip.Error - if state == StateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + var err tcpip.Error + id := e.net.Info().ID + id.LocalPort = e.localPort + id.RemotePort = e.remotePort + id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { panic(err) } - } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound - // A local unicast address is specified, verify that it's valid. - if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - // Our saved state had a port, but we don't actually have a - // reservation. We need to remove the port from our state, but still - // pass it to the reservation machinery. - id := e.ID - e.ID.LocalPort = 0 - e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) - if err != nil { - panic(err) + e.localPort = id.LocalPort + e.remotePort = id.RemotePort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 7c357cb09..7238fc019 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -70,28 +70,29 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) + ep.mu.Lock() + defer ep.mu.Unlock() + netHdr := r.pkt.Network() - route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) - if err != nil { + if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil { + return nil, err + } + + if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil { return nil, err } - ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() - route.Release() return nil, err } - ep.ID = r.id - ep.route = route - ep.dstPort = r.id.RemotePort + ep.localPort = r.id.LocalPort + ep.remotePort = r.id.RemotePort ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} - ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = uint32(StateConnected) - ep.rcvMu.Lock() ep.rcvReady = true ep.rcvMu.Unlock() diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 3f667cd74..1dd0048ac 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1089,13 +1089,14 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st return inet.NewRootNamespace(hostinet.NewStack(), nil), nil case config.NetworkNone, config.NetworkSandbox: - s, err := newEmptySandboxNetworkStack(clock, uniqueID) + s, err := newEmptySandboxNetworkStack(clock, uniqueID, conf.AllowPacketEndpointWrite) if err != nil { return nil, err } creator := &sandboxNetstackCreator{ - clock: clock, - uniqueID: uniqueID, + clock: clock, + uniqueID: uniqueID, + allowPacketEndpointWrite: conf.AllowPacketEndpointWrite, } return inet.NewRootNamespace(s, creator), nil @@ -1105,7 +1106,7 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st } -func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { +func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID, allowPacketEndpointWrite bool) (inet.Stack, error) { netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol} transProtos := []stack.TransportProtocolFactory{ tcp.NewProtocol, @@ -1121,9 +1122,10 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in HandleLocal: true, // Enable raw sockets for users with sufficient // privileges. - RawFactory: raw.EndpointFactory{}, - UniqueID: uniqueID, - DefaultIPTables: netfilter.DefaultLinuxTables, + RawFactory: raw.EndpointFactory{}, + AllowPacketEndpointWrite: allowPacketEndpointWrite, + UniqueID: uniqueID, + DefaultIPTables: netfilter.DefaultLinuxTables, })} // Enable SACK Recovery. @@ -1160,13 +1162,14 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in // // +stateify savable type sandboxNetstackCreator struct { - clock tcpip.Clock - uniqueID stack.UniqueID + clock tcpip.Clock + uniqueID stack.UniqueID + allowPacketEndpointWrite bool } // CreateStack implements kernel.NetworkStackCreator.CreateStack. func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) { - s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID) + s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID, f.allowPacketEndpointWrite) if err != nil { return nil, err } diff --git a/runsc/config/config.go b/runsc/config/config.go index 2f52863ff..2ce8cc006 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -86,6 +86,9 @@ type Config struct { // capabilities. EnableRaw bool `flag:"net-raw"` + // AllowPacketEndpointWrite enables write operations on packet endpoints. + AllowPacketEndpointWrite bool `flag:"TESTONLY-allow-packet-endpoint-write"` + // HardwareGSO indicates that hardware segmentation offload is enabled. HardwareGSO bool `flag:"gso"` diff --git a/runsc/config/flags.go b/runsc/config/flags.go index 85507902a..cc5aba474 100644 --- a/runsc/config/flags.go +++ b/runsc/config/flags.go @@ -92,6 +92,7 @@ func RegisterFlags() { // Test flags, not to be used outside tests, ever. flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.") + flag.Bool("TESTONLY-allow-packet-endpoint-write", false, "TEST ONLY; do not ever use! Used for tests to allow writes on packet sockets.") }) } diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index ea83bbe72..49f41c887 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -28,10 +28,10 @@ #include <iostream> #include <unordered_map> +#include "absl/strings/str_format.h" #include "include/grpcpp/security/server_credentials.h" #include "include/grpcpp/server_builder.h" #include "include/grpcpp/server_context.h" -#include "absl/strings/str_format.h" #include "test/packetimpact/proto/posix_server.grpc.pb.h" #include "test/packetimpact/proto/posix_server.pb.h" diff --git a/test/runner/main.go b/test/runner/main.go index 34e9c6279..2cab8d2d4 100644 --- a/test/runner/main.go +++ b/test/runner/main.go @@ -170,6 +170,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { "-network", *network, "-log-format=text", "-TESTONLY-unsafe-nonroot=true", + "-TESTONLY-allow-packet-endpoint-write=true", "-net-raw=true", fmt.Sprintf("-panic-signal=%d", unix.SIGTERM), "-watchdog-action=panic", diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 01ee432cb..b06b3d233 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -3293,9 +3293,12 @@ cc_library( ], deps = [ ":unix_domain_socket_test_util", + "//test/util:file_descriptor", + "//test/util:memory_util", "//test/util:socket_util", "@com_google_absl//absl/strings", gtest, + "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", ], diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc index 8202d35fa..a2cc59e83 100644 --- a/test/syscalls/linux/eventfd.cc +++ b/test/syscalls/linux/eventfd.cc @@ -149,31 +149,6 @@ TEST(EventfdTest, BigWriteBigRead) { EXPECT_EQ(l[0], 1); } -TEST(EventfdTest, SpliceFromPipePartialSucceeds) { - int pipes[2]; - ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds()); - const FileDescriptor pipe_rfd(pipes[0]); - const FileDescriptor pipe_wfd(pipes[1]); - constexpr uint64_t kVal{1}; - - FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK)); - - uint64_t event_array[2]; - event_array[0] = kVal; - event_array[1] = kVal; - ASSERT_THAT(write(pipe_wfd.get(), event_array, sizeof(event_array)), - SyscallSucceedsWithValue(sizeof(event_array))); - EXPECT_THAT(splice(pipe_rfd.get(), /*__offin=*/nullptr, efd.get(), - /*__offout=*/nullptr, sizeof(event_array[0]) + 1, - SPLICE_F_NONBLOCK), - SyscallSucceedsWithValue(sizeof(event_array[0]))); - - uint64_t val; - ASSERT_THAT(read(efd.get(), &val, sizeof(val)), - SyscallSucceedsWithValue(sizeof(val))); - EXPECT_EQ(val, kVal); -} - // NotifyNonZero is inherently racy, so random save is disabled. TEST(EventfdTest, NotifyNonZero) { // Waits will time out at 10 seconds. diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc index f6b78989b..e2622232d 100644 --- a/test/syscalls/linux/inotify.cc +++ b/test/syscalls/linux/inotify.cc @@ -1849,34 +1849,6 @@ TEST(Inotify, SpliceOnWatchTarget) { })); } -TEST(Inotify, SpliceOnInotifyFD) { - int pipefds[2]; - ASSERT_THAT(pipe2(pipefds, O_NONBLOCK), SyscallSucceeds()); - - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - root.path(), "some content", TempPath::kDefaultFileMode)); - - const FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); - const int watcher = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - - char buf; - EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - - EXPECT_THAT(splice(fd.get(), nullptr, pipefds[1], nullptr, - sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK), - SyscallSucceedsWithValue(sizeof(struct inotify_event))); - - const FileDescriptor read_fd(pipefds[0]); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get())); - ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)})); -} - // Watches on a parent should not be triggered by actions on a hard link to one // of its children that has a different parent. TEST(Inotify, LinkOnOtherParent) { diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index ca4ab0aad..bfa5d179a 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -285,14 +285,6 @@ TEST_P(CookedPacketTest, Send) { memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr)); memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage)); - // We don't implement writing to packet sockets on gVisor. - if (IsRunningOnGvisor()) { - ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, - reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), - SyscallFailsWithErrno(EINVAL)); - GTEST_SKIP(); - } - // Send it. ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 61714d1da..e57c60ffa 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -296,14 +296,6 @@ TEST_P(RawPacketTest, Send) { memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage)); - // We don't implement writing to packet sockets on gVisor. - if (IsRunningOnGvisor()) { - ASSERT_THAT(sendto(s_, send_buf, sizeof(send_buf), 0, - reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), - SyscallFailsWithErrno(EINVAL)); - GTEST_SKIP(); - } - // Send it. ASSERT_THAT(sendto(s_, send_buf, sizeof(send_buf), 0, reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index bea4ee71c..9bd3bd5e8 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -208,38 +208,6 @@ TEST(SendFileTest, SendAndUpdateFileOffset) { absl::string_view(actual, kHalfDataSize)); } -TEST(SendFileTest, SendToDevZeroAndUpdateFileOffset) { - // Create temp files. - // Test input string length must be > 2 AND even. - constexpr char kData[] = "The slings and arrows of outrageous fortune,"; - constexpr int kDataSize = sizeof(kData) - 1; - constexpr int kHalfDataSize = kDataSize / 2; - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); - - // Open the input file as read only. - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); - - // Open /dev/zero as write only. - const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); - - // Send data and verify that sendfile returns the correct value. - int bytes_sent; - EXPECT_THAT( - bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize), - SyscallSucceedsWithValue(kHalfDataSize)); - - char actual[kHalfDataSize]; - // Verify that the input file offset has been updated. - ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent), - SyscallSucceedsWithValue(kHalfDataSize)); - EXPECT_EQ( - absl::string_view(kData + kDataSize - bytes_sent, kDataSize - bytes_sent), - absl::string_view(actual, kHalfDataSize)); -} - TEST(SendFileTest, SendAndUpdateFileOffsetFromNonzeroStartingPoint) { // Create temp files. // Test input string length must be > 2 AND divisible by 4. @@ -609,23 +577,6 @@ TEST(SendFileTest, SendPipeBlocks) { SyscallSucceedsWithValue(kDataSize)); } -TEST(SendFileTest, SendToSpecialFile) { - // Create temp file. - const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - - const FileDescriptor inf = - ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); - constexpr int kSize = 0x7ff; - ASSERT_THAT(ftruncate(inf.get(), kSize), SyscallSucceeds()); - - auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); - - // eventfd can accept a number of bytes which is a multiple of 8. - EXPECT_THAT(sendfile(eventfd.get(), inf.get(), nullptr, 0xfffff), - SyscallSucceedsWithValue(kSize & (~7))); -} - TEST(SendFileTest, SendFileToPipe) { // Create temp file. constexpr char kData[] = "<insert-quote-here>"; @@ -672,57 +623,6 @@ TEST(SendFileTest, SendFileToSelf) { SyscallSucceedsWithValue(kSendfileSize)); } -static volatile int signaled = 0; -void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; } - -TEST(SendFileTest, ToEventFDDoesNotSpin) { - FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); - - // Write the maximum value of an eventfd to a file. - const uint64_t kMaxEventfdValue = 0xfffffffffffffffe; - const auto tempfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const auto tempfd = ASSERT_NO_ERRNO_AND_VALUE(Open(tempfile.path(), O_RDWR)); - ASSERT_THAT( - pwrite(tempfd.get(), &kMaxEventfdValue, sizeof(kMaxEventfdValue), 0), - SyscallSucceedsWithValue(sizeof(kMaxEventfdValue))); - - // Set the eventfd's value to 1. - const uint64_t kOne = 1; - ASSERT_THAT(write(efd.get(), &kOne, sizeof(kOne)), - SyscallSucceedsWithValue(sizeof(kOne))); - - // Set up signal handler. - struct sigaction sa = {}; - sa.sa_sigaction = SigUsr1Handler; - sa.sa_flags = SA_SIGINFO; - const auto cleanup_sigact = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa)); - - // Send SIGUSR1 to this thread in 1 second. - struct sigevent sev = {}; - sev.sigev_notify = SIGEV_THREAD_ID; - sev.sigev_signo = SIGUSR1; - sev.sigev_notify_thread_id = gettid(); - auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); - struct itimerspec its = {}; - its.it_value = absl::ToTimespec(absl::Seconds(1)); - DisableSave ds; // Asserting an EINTR. - ASSERT_NO_ERRNO(timer.Set(0, its)); - - // Sendfile from tempfd to the eventfd. Since the eventfd is not already at - // its maximum value, the eventfd is "ready for writing"; however, since the - // eventfd's existing value plus the new value would exceed the maximum, the - // write should internally fail with EWOULDBLOCK. In this case, sendfile() - // should block instead of spinning, and eventually be interrupted by our - // timer. See b/172075629. - EXPECT_THAT( - sendfile(efd.get(), tempfd.get(), nullptr, sizeof(kMaxEventfdValue)), - SyscallFailsWithErrno(EINTR)); - - // Signal should have been handled. - EXPECT_EQ(signaled, 1); -} - } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index cf96b2075..43433eaae 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -27,7 +27,10 @@ #include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/memory_util.h" #include "test/util/socket_util.h" +#include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -268,6 +271,18 @@ TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) { } } +// Repro for b/196804997. +TEST_P(UnixSocketPairTest, SendFromMmapBeyondEof) { + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + Mmap(nullptr, kPageSize, PROT_READ, MAP_SHARED, fd.get(), 0)); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + ASSERT_THAT(send(sockets->first_fd(), m.ptr(), m.len(), 0), + SyscallFailsWithErrno(EFAULT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 6e9f70f8c..2f3cfc3f3 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -181,6 +181,21 @@ TEST_P(StreamUnixSocketPairTest, SetSocketSendBuf) { ASSERT_EQ(quarter_sz, val); } +TEST_P(StreamUnixSocketPairTest, SendBufferOverflow) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + auto s = sockets->first_fd(); + + constexpr int kBufSz = 4096; + std::vector<char> buf(kBufSz * 4); + ASSERT_THAT(RetryEINTR(send)(s, buf.data(), buf.size(), MSG_DONTWAIT), + SyscallSucceeds()); + // The new buffer size should be smaller that the amount of data in the queue. + ASSERT_THAT(setsockopt(s, SOL_SOCKET, SO_SNDBUF, &kBufSz, sizeof(kBufSz)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(send)(s, buf.data(), buf.size(), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + TEST_P(StreamUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int sock = sockets->first_fd(); diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index c85f6da0b..4a10ae8d2 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -195,81 +195,6 @@ TEST(SpliceTest, PipeOffsets) { SyscallFailsWithErrno(ESPIPE)); } -// Event FDs may be used with splice without an offset. -TEST(SpliceTest, FromEventFD) { - // Open the input eventfd with an initial value so that it is readable. - constexpr uint64_t kEventFDValue = 1; - int efd; - ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds()); - const FileDescriptor in_fd(efd); - - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Splice 8-byte eventfd value to pipe. - constexpr int kEventFDSize = 8; - EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), - SyscallSucceedsWithValue(kEventFDSize)); - - // Contents should be equal. - std::vector<char> rbuf(kEventFDSize); - ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), - SyscallSucceedsWithValue(kEventFDSize)); - EXPECT_EQ(memcmp(rbuf.data(), &kEventFDValue, rbuf.size()), 0); -} - -// Event FDs may not be used with splice with an offset. -TEST(SpliceTest, FromEventFDOffset) { - int efd; - ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor in_fd(efd); - - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Attempt to splice 8-byte eventfd value to pipe with offset. - // - // This is not allowed because eventfd doesn't support pread. - constexpr int kEventFDSize = 8; - loff_t in_off = 0; - EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - -// Event FDs may not be used with splice with an offset. -TEST(SpliceTest, ToEventFDOffset) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - const FileDescriptor wfd(fds[1]); - - // Fill with a value. - constexpr int kEventFDSize = 8; - std::vector<char> buf(kEventFDSize); - buf[0] = 1; - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kEventFDSize)); - - int efd; - ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor out_fd(efd); - - // Attempt to splice 8-byte eventfd value to pipe with offset. - // - // This is not allowed because eventfd doesn't support pwrite. - loff_t out_off = 0; - EXPECT_THAT( - splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0), - SyscallFailsWithErrno(EINVAL)); -} - TEST(SpliceTest, ToPipe) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -852,34 +777,6 @@ TEST(SpliceTest, FromPipeMaxFileSize) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } -TEST(SpliceTest, FromPipeToDevZero) { - // Create a new pipe. - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - const FileDescriptor rfd(fds[0]); - FileDescriptor wfd(fds[1]); - - // Fill with some random data. - std::vector<char> buf(kPageSize); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(kPageSize)); - - const FileDescriptor zero = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); - - // Close the write end to prevent blocking below. - wfd.reset(); - - // Splice to /dev/zero. The first call should empty the pipe, and the return - // value should not exceed the number of bytes available for reading. - EXPECT_THAT( - splice(rfd.get(), nullptr, zero.get(), nullptr, kPageSize + 123, 0), - SyscallSucceedsWithValue(kPageSize)); - EXPECT_THAT(splice(rfd.get(), nullptr, zero.get(), nullptr, 1, 0), - SyscallSucceedsWithValue(0)); -} - static volatile int signaled = 0; void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; } diff --git a/test/util/BUILD b/test/util/BUILD index b92af1c27..2dcf71613 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_library", "cc_test", "coreutil", "default_net_util", "gbenchmark", "gtest", "select_system") +load("//tools:defs.bzl", "cc_library", "cc_test", "coreutil", "default_net_util", "gbenchmark_internal", "gtest", "select_system") package( default_visibility = ["//:sandbox"], @@ -295,7 +295,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", gtest, - gbenchmark, + gbenchmark_internal, ], ) diff --git a/tools/bazeldefs/cc.bzl b/tools/bazeldefs/cc.bzl index 2831eac5f..57d33726a 100644 --- a/tools/bazeldefs/cc.bzl +++ b/tools/bazeldefs/cc.bzl @@ -9,6 +9,7 @@ cc_test = _cc_test cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" gtest = "@com_google_googletest//:gtest" gbenchmark = "@com_google_benchmark//:benchmark" +gbenchmark_internal = "@com_google_benchmark//:benchmark" grpcpp = "@com_github_grpc_grpc//:grpc++" vdso_linker_option = "-fuse-ld=gold " diff --git a/tools/defs.bzl b/tools/defs.bzl index 27542a2f5..f4266e1de 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -9,7 +9,7 @@ load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") load("//tools/nogo:defs.bzl", "nogo_test") load("//tools/bazeldefs:defs.bzl", _arch_genrule = "arch_genrule", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _more_shards = "more_shards", _most_shards = "most_shards", _proto_library = "proto_library", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _version = "version") -load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option") +load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _gbenchmark_internal = "gbenchmark_internal", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option") load("//tools/bazeldefs:go.bzl", _bazel_worker_proto = "bazel_worker_proto", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_rule = "go_rule", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos") load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") @@ -37,6 +37,7 @@ cc_library = _cc_library cc_test = _cc_test cc_toolchain = _cc_toolchain gbenchmark = _gbenchmark +gbenchmark_internal = _gbenchmark_internal gtest = _gtest grpcpp = _grpcpp vdso_linker_option = _vdso_linker_option diff --git a/website/_config.yml b/website/_config.yml index dc44945bc..5f1cbbdeb 100644 --- a/website/_config.yml +++ b/website/_config.yml @@ -44,3 +44,6 @@ authors: mpratt: name: Michael Pratt email: mpratt@google.com + nybidari: + name: Nayana Bidari + email: nybidari@google.com diff --git a/website/assets/images/2021-08-31-rack-figure1.png b/website/assets/images/2021-08-31-rack-figure1.png Binary files differnew file mode 100644 index 000000000..6d9fdb147 --- /dev/null +++ b/website/assets/images/2021-08-31-rack-figure1.png diff --git a/website/assets/images/2021-08-31-rack-figure2.png b/website/assets/images/2021-08-31-rack-figure2.png Binary files differnew file mode 100644 index 000000000..c2043ecae --- /dev/null +++ b/website/assets/images/2021-08-31-rack-figure2.png diff --git a/website/assets/images/2021-08-31-rack-figure3.png b/website/assets/images/2021-08-31-rack-figure3.png Binary files differnew file mode 100644 index 000000000..e8b689f33 --- /dev/null +++ b/website/assets/images/2021-08-31-rack-figure3.png diff --git a/website/blog/2021-08-31-gvisor-rack.md b/website/blog/2021-08-31-gvisor-rack.md new file mode 100644 index 000000000..e7d4582e4 --- /dev/null +++ b/website/blog/2021-08-31-gvisor-rack.md @@ -0,0 +1,120 @@ +# gVisor RACK + +gVisor has implemented the [RACK](https://datatracker.ietf.org/doc/html/rfc8985) +(Recent ACKnowledgement) TCP loss-detection algorithm in our network stack, +which improves throughput in the presence of packet loss and reordering. + +TCP is a connection-oriented protocol that detects and recovers from loss by +retransmitting packets. [RACK](https://datatracker.ietf.org/doc/html/rfc8985) is +one of the recent loss-detection methods implemented in Linux and BSD, which +helps in identifying packet loss quickly and accurately in the presence of +packet reordering and tail losses. + +## Background + +The TCP congestion window indicates the number of unacknowledged packets that +can be sent at any time. When packet loss is identified, the congestion window +is reduced depending on the type of loss. The sender will recover from the loss +after all the packets sent before reducing the congestion window are +acknowledged. If the loss is identified falsely by the connection, then the +connection enters loss recovery unnecessarily, resulting in sending fewer +packets. + +Packet loss is identified mainly in two ways: + +1. Three duplicate acknowledgments, which will result in either + [Fast](https://datatracker.ietf.org/doc/html/rfc2001#section-4) or + [SACK](https://datatracker.ietf.org/doc/html/rfc6675) recovery. The + congestion window is reduced depending on the type of congestion control + algorithm. For example, in the + [Reno](https://en.wikipedia.org/wiki/TCP_congestion_control#TCP_Tahoe_and_Reno) + algorithm it is reduced to half. +2. RTO (Retransmission Timeout) which will result in Timeout recovery. The + congestion window is reduced to one + [MSS](https://en.wikipedia.org/wiki/Maximum_segment_size). + +Both of these cases result in reducing the congestion window, with RTO being +more expensive. Most of the existing algorithms do not detect packet reordering, +which get incorrectly identified as packet loss, resulting in an RTO. +Furthermore, the loss of an ACK at the end of a sequence (known as "tail loss") +will also trigger RTO and slow down future transmissions unnecessarily. RACK +helps us to identify loss accurately in all these scenarios, and will avoid +entering RTO. + +## Implementation of RACK + +Implementation of RACK requires support for: + +1. Per-packet transmission timestamps: RACK detects loss depending on the + transmission times of the packet and the timestamp at which ACK was + received. +2. SACK and ability to detect DSACK: Selective Acknowledgement and Duplicate + SACK are used to adjust the timer window after which a packet can be marked + as lost. + +### Packet Reordering + +Packet reordering commonly occurs when different packets take different paths +through a network. The diagram below shows the transmission of four packets +which get reordered in transmission, and the resulting TCP behavior with and +without RACK. + +![Figure 1](/assets/images/2021-08-31-rack-figure1.png "Packet reordering.") + +In the above example, the sender sees three duplicate acknowledgments. Without +RACK, this is identified falsely as packet loss, and the congestion window will +be reduced after entering Fast/SACK recovery. + +To detect packet reordering, RACK uses a reorder window, bounded between +[[RTT](https://en.wikipedia.org/wiki/Round-trip_delay)/4, RTT]. The reorder +timer is set to expire after _RTT+reorder\_window_. A packet is marked as lost +when the packets following it were acknowledged using SACK and the reorder timer +expires. The reorder window is increased when a DSACK is received (which +indicates that there is a higher degree of reordering). + +### Tail Loss + +Tail loss occurs when the packets are lost at the end of data transmission. The +diagram below shows an example of tail loss when the last three packets are +lost, and how it is handled with and without RACK. + +![Figure 2](/assets/images/2021-08-31-rack-figure2.png "Tail loss figure 2.") + +For tail losses, RACK uses a Tail Loss Probe (TLP), which relies on a timer for +the last packet sent. The TLP timer is set to _2 \* RTT,_ after which a probe is +sent. The probe packet will allow the connection one more chance to detect a +loss by triggering ACK feedback to avoid entering RTO. In the above example, the +loss is recovered without entering the RTO. + +TLP will also help in cases where the ACK was lost but all the packets were +received by the receiver. The below diagram shows that the ACK received for the +probe packet avoided the RTO. + +![Figure 3](/assets/images/2021-08-31-rack-figure3.png "Tail loss figure 3.") + +If there was some loss, then the ACK for the probe packet will have the SACK +blocks, which will be used to detect and retransmit the lost packets. + +In gVisor, we have support for +[NewReno](https://datatracker.ietf.org/doc/html/rfc6582) and SACK loss recovery +methods. We +[added support for RACK](https://github.com/google/gvisor/issues/5243) recently, +and it is the default when SACK is enabled. After enabling RACK, our internal +benchmarks in the presence of reordering and tail losses and the data we took +from internal users inside Google have shown ~50% reduction in the number of +RTOs. + +While RACK has improved one aspect of TCP performance by reducing the timeouts +in the presence of reordering and tail losses, in gVisor we plan to implement +the undoing of congestion windows and +[BBRv2](https://datatracker.ietf.org/doc/html/draft-cardwell-iccrg-bbr-congestion-control) +(once there is an RFC available) to further improve TCP performance in less +ideal network conditions. + +If you haven’t already, try gVisor. The instructions to get started are in our +[Quick Start](https://gvisor.dev/docs/user_guide/quick_start/docker/). You can +also get involved with the gVisor community via our +[Gitter channel](https://gitter.im/gvisor/community), +[email list](https://groups.google.com/forum/#!forum/gvisor-users), +[issue tracker](https://gvisor.dev/issue/new), and +[Github repository](https://github.com/google/gvisor). diff --git a/website/blog/BUILD b/website/blog/BUILD index 17beb721f..0384b9ba9 100644 --- a/website/blog/BUILD +++ b/website/blog/BUILD @@ -49,6 +49,16 @@ doc( permalink = "/blog/2020/10/22/platform-portability/", ) +doc( + name = "gvisor-rack", + src = "2021-08-31-gvisor-rack.md", + authors = [ + "nybidari", + ], + layout = "post", + permalink = "/blog/2021/08/31/gvisor-rack/", +) + docs( name = "posts", deps = [ |