diff options
Diffstat (limited to 'pkg')
499 files changed, 20273 insertions, 6240 deletions
diff --git a/pkg/abi/linux/errno/errno.go b/pkg/abi/linux/errno/errno.go index 5a09c6605..38ebbb1d7 100644 --- a/pkg/abi/linux/errno/errno.go +++ b/pkg/abi/linux/errno/errno.go @@ -157,9 +157,32 @@ const ( EHWPOISON ) -// errnos derived from other errnos +// errnos derived from other errnos. const ( EWOULDBLOCK = EAGAIN EDEADLOCK = EDEADLK ENONET = ENOENT ) + +// errnos for internal errors. +const ( + // ERESTARTSYS is returned by an interrupted syscall to indicate that it + // should be converted to EINTR if interrupted by a signal delivered to a + // user handler without SA_RESTART set, and restarted otherwise. + ERESTARTSYS = 512 + + // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it + // should always be restarted. + ERESTARTNOINTR = 513 + + // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it + // should be converted to EINTR if interrupted by a signal delivered to a + // user handler, and restarted otherwise. + ERESTARTNOHAND = 514 + + // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate + // that it should be restarted using a custom function. The interrupted + // syscall must register a custom restart function by calling + // Task.SetRestartSyscallFn. + ERESTART_RESTARTBLOCK = 516 +) diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 1e23850a9..67646f837 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -242,7 +242,7 @@ const ( // Statx represents struct statx. // -// +marshal +// +marshal slice:StatxSlice type Statx struct { Mask uint32 Blksize uint32 @@ -270,6 +270,8 @@ type Statx struct { var SizeOfStatx = (*Statx)(nil).SizeBytes() // FileMode represents a mode_t. +// +// +marshal type FileMode uint16 // Permissions returns just the permission bits. diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 006b5a525..29062c97a 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -170,3 +170,18 @@ const ( KCOV_MODE_TRACE_PC = 2 KCOV_MODE_TRACE_CMP = 3 ) + +// Attestation ioctls. +var ( + SIGN_ATTESTATION_REPORT = IOC(_IOC_READ, 's', 1, 65) +) + +// SizeOfQuoteInputData is the number of bytes in the input data of ioctl call +// to get quote. +const SizeOfQuoteInputData = 64 + +// SignReport is a struct that gets signed quote from input data. +type SignReport struct { + data [64]byte + quote []byte +} diff --git a/pkg/abi/linux/msgqueue.go b/pkg/abi/linux/msgqueue.go index e1e8d0357..0612a8214 100644 --- a/pkg/abi/linux/msgqueue.go +++ b/pkg/abi/linux/msgqueue.go @@ -47,7 +47,7 @@ const ( MSGSSZ = 16 // MSGSEG is simplified due to the inexistance of a ternary operator. - MSGSEG = (MSGPOOL * 1024) / MSGSSZ + MSGSEG = 0xffff ) // MsqidDS is equivelant to struct msqid64_ds. Source: diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 95871b8a5..f60e42997 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -542,6 +542,15 @@ type ControlMessageIPPacketInfo struct { DestinationAddr InetAddr } +// ControlMessageIPv6PacketInfo represents struct in6_pktinfo from linux/ipv6.h. +// +// +marshal +// +stateify savable +type ControlMessageIPv6PacketInfo struct { + Addr Inet6Addr + NIC uint32 +} + // SizeOfControlMessageCredentials is the binary size of a // ControlMessageCredentials struct. var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes() @@ -566,6 +575,10 @@ const SizeOfControlMessageTClass = 4 // control message. const SizeOfControlMessageIPPacketInfo = 12 +// SizeOfControlMessageIPv6PacketInfo is the size of a +// ControlMessageIPv6PacketInfo. +const SizeOfControlMessageIPv6PacketInfo = 20 + // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index bd3a5cce9..6d8b5f818 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -8,7 +8,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/context", - "//pkg/syserror", + "//pkg/errors/linuxerr", ], ) diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go index d7acc1d9f..985199cfa 100644 --- a/pkg/amutex/amutex.go +++ b/pkg/amutex/amutex.go @@ -20,7 +20,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/errors/linuxerr" ) // Sleeper must be implemented by users of the abortable mutex to allow for @@ -33,7 +33,7 @@ type NoopSleeper = context.Context // Block blocks until either receiving from ch succeeds (in which case it // returns nil) or sleeper is interrupted (in which case it returns -// syserror.ErrInterrupted). +// linuxerr.ErrInterrupted). func Block(sleeper Sleeper, ch <-chan struct{}) error { cancel := sleeper.SleepStart() select { @@ -42,7 +42,7 @@ func Block(sleeper Sleeper, ch <-chan struct{}) error { return nil case <-cancel: sleeper.SleepFinish(false) - return syserror.ErrInterrupted + return linuxerr.ErrInterrupted } } diff --git a/pkg/atomicbitops/aligned_32bit_unsafe.go b/pkg/atomicbitops/aligned_32bit_unsafe.go index a17d317cc..0e4765c48 100644 --- a/pkg/atomicbitops/aligned_32bit_unsafe.go +++ b/pkg/atomicbitops/aligned_32bit_unsafe.go @@ -39,9 +39,9 @@ type AlignedAtomicInt64 struct { } func (aa *AlignedAtomicInt64) ptr() *int64 { - // On 32-bit systems, aa.value is is guaranteed to be 32-bit aligned. - // It means that in the 12-byte aa.value, there are guaranteed to be 8 - // contiguous bytes with 64-bit alignment. + // On 32-bit systems, aa.value is guaranteed to be 32-bit aligned. It means + // that in the 12-byte aa.value, there are guaranteed to be 8 contiguous bytes + // with 64-bit alignment. return (*int64)(unsafe.Pointer((uintptr(unsafe.Pointer(&aa.value)) + 4) &^ 7)) } @@ -77,9 +77,9 @@ type AlignedAtomicUint64 struct { } func (aa *AlignedAtomicUint64) ptr() *uint64 { - // On 32-bit systems, aa.value is is guaranteed to be 32-bit aligned. - // It means that in the 12-byte aa.value, there are guaranteed to be 8 - // contiguous bytes with 64-bit alignment. + // On 32-bit systems, aa.value is guaranteed to be 32-bit aligned. It means + // that in the 12-byte aa.value, there are guaranteed to be 8 contiguous bytes + // with 64-bit alignment. return (*uint64)(unsafe.Pointer((uintptr(unsafe.Pointer(&aa.value)) + 4) &^ 7)) } diff --git a/pkg/atomicbitops/atomicbitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s index 54c887ee5..6b9a67adc 100644 --- a/pkg/atomicbitops/atomicbitops_amd64.s +++ b/pkg/atomicbitops/atomicbitops_amd64.s @@ -16,28 +16,28 @@ #include "textflag.h" -TEXT ·AndUint32(SB),$0-12 - MOVQ addr+0(FP), BP +TEXT ·AndUint32(SB),NOSPLIT,$0-12 + MOVQ addr+0(FP), BX MOVL val+8(FP), AX LOCK - ANDL AX, 0(BP) + ANDL AX, 0(BX) RET -TEXT ·OrUint32(SB),$0-12 - MOVQ addr+0(FP), BP +TEXT ·OrUint32(SB),NOSPLIT,$0-12 + MOVQ addr+0(FP), BX MOVL val+8(FP), AX LOCK - ORL AX, 0(BP) + ORL AX, 0(BX) RET -TEXT ·XorUint32(SB),$0-12 - MOVQ addr+0(FP), BP +TEXT ·XorUint32(SB),NOSPLIT,$0-12 + MOVQ addr+0(FP), BX MOVL val+8(FP), AX LOCK - XORL AX, 0(BP) + XORL AX, 0(BX) RET -TEXT ·CompareAndSwapUint32(SB),$0-20 +TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20 MOVQ addr+0(FP), DI MOVL old+8(FP), AX MOVL new+12(FP), DX @@ -46,28 +46,28 @@ TEXT ·CompareAndSwapUint32(SB),$0-20 MOVL AX, ret+16(FP) RET -TEXT ·AndUint64(SB),$0-16 - MOVQ addr+0(FP), BP +TEXT ·AndUint64(SB),NOSPLIT,$0-16 + MOVQ addr+0(FP), BX MOVQ val+8(FP), AX LOCK - ANDQ AX, 0(BP) + ANDQ AX, 0(BX) RET -TEXT ·OrUint64(SB),$0-16 - MOVQ addr+0(FP), BP +TEXT ·OrUint64(SB),NOSPLIT,$0-16 + MOVQ addr+0(FP), BX MOVQ val+8(FP), AX LOCK - ORQ AX, 0(BP) + ORQ AX, 0(BX) RET -TEXT ·XorUint64(SB),$0-16 - MOVQ addr+0(FP), BP +TEXT ·XorUint64(SB),NOSPLIT,$0-16 + MOVQ addr+0(FP), BX MOVQ val+8(FP), AX LOCK - XORQ AX, 0(BP) + XORQ AX, 0(BX) RET -TEXT ·CompareAndSwapUint64(SB),$0-32 +TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32 MOVQ addr+0(FP), DI MOVQ old+8(FP), AX MOVQ new+16(FP), DX diff --git a/pkg/atomicbitops/atomicbitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s index 5c780851b..644a6bca5 100644 --- a/pkg/atomicbitops/atomicbitops_arm64.s +++ b/pkg/atomicbitops/atomicbitops_arm64.s @@ -16,7 +16,7 @@ #include "textflag.h" -TEXT ·AndUint32(SB),$0-12 +TEXT ·AndUint32(SB),NOSPLIT,$0-12 MOVD ptr+0(FP), R0 MOVW val+8(FP), R1 again: @@ -26,7 +26,7 @@ again: CBNZ R3, again RET -TEXT ·OrUint32(SB),$0-12 +TEXT ·OrUint32(SB),NOSPLIT,$0-12 MOVD ptr+0(FP), R0 MOVW val+8(FP), R1 again: @@ -36,7 +36,7 @@ again: CBNZ R3, again RET -TEXT ·XorUint32(SB),$0-12 +TEXT ·XorUint32(SB),NOSPLIT,$0-12 MOVD ptr+0(FP), R0 MOVW val+8(FP), R1 again: @@ -46,7 +46,7 @@ again: CBNZ R3, again RET -TEXT ·CompareAndSwapUint32(SB),$0-20 +TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20 MOVD addr+0(FP), R0 MOVW old+8(FP), R1 MOVW new+12(FP), R2 @@ -60,7 +60,7 @@ done: MOVW R3, prev+16(FP) RET -TEXT ·AndUint64(SB),$0-16 +TEXT ·AndUint64(SB),NOSPLIT,$0-16 MOVD ptr+0(FP), R0 MOVD val+8(FP), R1 again: @@ -70,7 +70,7 @@ again: CBNZ R3, again RET -TEXT ·OrUint64(SB),$0-16 +TEXT ·OrUint64(SB),NOSPLIT,$0-16 MOVD ptr+0(FP), R0 MOVD val+8(FP), R1 again: @@ -80,7 +80,7 @@ again: CBNZ R3, again RET -TEXT ·XorUint64(SB),$0-16 +TEXT ·XorUint64(SB),NOSPLIT,$0-16 MOVD ptr+0(FP), R0 MOVD val+8(FP), R1 again: @@ -90,7 +90,7 @@ again: CBNZ R3, again RET -TEXT ·CompareAndSwapUint64(SB),$0-32 +TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32 MOVD addr+0(FP), R0 MOVD old+8(FP), R1 MOVD new+16(FP), R2 diff --git a/pkg/atomicbitops/atomicbitops_noasm.go b/pkg/atomicbitops/atomicbitops_noasm.go index 474c0c815..af6b1362e 100644 --- a/pkg/atomicbitops/atomicbitops_noasm.go +++ b/pkg/atomicbitops/atomicbitops_noasm.go @@ -21,6 +21,7 @@ import ( "sync/atomic" ) +//go:nosplit func AndUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -31,6 +32,7 @@ func AndUint32(addr *uint32, val uint32) { } } +//go:nosplit func OrUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -41,6 +43,7 @@ func OrUint32(addr *uint32, val uint32) { } } +//go:nosplit func XorUint32(addr *uint32, val uint32) { for { o := atomic.LoadUint32(addr) @@ -51,6 +54,7 @@ func XorUint32(addr *uint32, val uint32) { } } +//go:nosplit func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) { for { prev = atomic.LoadUint32(addr) @@ -63,6 +67,7 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) { } } +//go:nosplit func AndUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -73,6 +78,7 @@ func AndUint64(addr *uint64, val uint64) { } } +//go:nosplit func OrUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -83,6 +89,7 @@ func OrUint64(addr *uint64, val uint64) { } } +//go:nosplit func XorUint64(addr *uint64, val uint64) { for { o := atomic.LoadUint64(addr) @@ -93,6 +100,7 @@ func XorUint64(addr *uint64, val uint64) { } } +//go:nosplit func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) { for { prev = atomic.LoadUint64(addr) diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go index 7bcfcd543..a4610f977 100644 --- a/pkg/buffer/view.go +++ b/pkg/buffer/view.go @@ -378,6 +378,20 @@ func (v *View) Copy() (other View) { return } +// Clone makes a more shallow copy compared to Copy. The underlying payload +// slice (buffer.data) is shared but the buffers themselves are copied. +func (v *View) Clone() *View { + other := &View{ + size: v.size, + } + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + newBuf := other.pool.getNoInit() + *newBuf = *buf + other.data.PushBack(newBuf) + } + return other +} + // Apply applies the given function across all valid data. func (v *View) Apply(fn func([]byte)) { for buf := v.data.Front(); buf != nil; buf = buf.Next() { diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 796efa240..59784eacb 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -509,6 +509,24 @@ func TestView(t *testing.T) { } } +func TestViewClone(t *testing.T) { + const ( + originalSize = 90 + bytesToDelete = 30 + ) + var v View + v.AppendOwned(bytes.Repeat([]byte{originalSize}, originalSize)) + + clonedV := v.Clone() + v.TrimFront(bytesToDelete) + if got, want := int(v.Size()), originalSize-bytesToDelete; got != want { + t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) + } + if got := clonedV.Size(); got != originalSize { + t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) + } +} + func TestViewPullUp(t *testing.T) { for _, tc := range []struct { desc string diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index 69eeb7528..4d5e062a8 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -37,6 +37,14 @@ package cpuid // arch/arm64/include/uapi/asm/hwcap.h type Feature int +// HostFeatureSet returns a FeatureSet that matches that of the host machine. +// Callers must not mutate the returned FeatureSet. +func HostFeatureSet() *FeatureSet { + return hostFeatureSet +} + +var hostFeatureSet = getHostFeatureSet() + // ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a // subset of the host feature set. type ErrIncompatible struct { diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go index 6e61d562f..04645aed5 100644 --- a/pkg/cpuid/cpuid_arm64.go +++ b/pkg/cpuid/cpuid_arm64.go @@ -230,6 +230,16 @@ type FeatureSet struct { CPURevision uint8 } +// Clone returns a copy of fs. +func (fs *FeatureSet) Clone() *FeatureSet { + fs2 := *fs + fs2.Set = make(map[Feature]bool) + for f, b := range fs.Set { + fs2.Set[f] = b + } + return &fs2 +} + // CheckHostCompatible returns nil if fs is a subset of the host feature set. // Noop on arm64. func (fs *FeatureSet) CheckHostCompatible() error { @@ -292,9 +302,9 @@ func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) { fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline. } -// HostFeatureSet uses hwCap to get host values and construct a feature set +// getHostFeatureSet uses hwCap to get host values and construct a feature set // that matches that of the host machine. -func HostFeatureSet() *FeatureSet { +func getHostFeatureSet() *FeatureSet { s := make(map[Feature]bool) for f := range arm64FeatureStrings { diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go index dc17cade8..a92d32d74 100644 --- a/pkg/cpuid/cpuid_x86.go +++ b/pkg/cpuid/cpuid_x86.go @@ -627,6 +627,17 @@ type FeatureSet struct { CacheLine uint32 } +// Clone returns a copy of fs. +func (fs *FeatureSet) Clone() *FeatureSet { + fs2 := *fs + fs2.Set = make(map[Feature]bool) + for f, b := range fs.Set { + fs2.Set[f] = b + } + fs2.Caches = append([]Cache(nil), fs.Caches...) + return &fs2 +} + // FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is // equivalent to the "flags" field in /proc/cpuinfo. func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string { @@ -961,13 +972,13 @@ func (fs *FeatureSet) UseXsaveopt() bool { // HostID executes a native CPUID instruction. func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32) -// HostFeatureSet uses cpuid to get host values and construct a feature set +// getHostFeatureSet uses cpuid to get host values and construct a feature set // that matches that of the host machine. Note that there are several places // where there appear to be some unnecessary assignments between register names // (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show // where the different feature blocks come from, to make the code easier to // inspect and read. -func HostFeatureSet() *FeatureSet { +func getHostFeatureSet() *FeatureSet { // eax=0 gets max supported feature and vendor ID. _, bx, cx, dx := HostID(0, 0) vendorID := vendorIDFromRegs(bx, cx, dx) diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go index 69e867386..28eba2ff6 100644 --- a/pkg/crypto/crypto_stdlib.go +++ b/pkg/crypto/crypto_stdlib.go @@ -19,14 +19,21 @@ package crypto import ( "crypto/ecdsa" + "crypto/elliptic" "crypto/sha512" + "fmt" "math/big" ) -// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the -// public key, pub. Its return value records whether the signature is valid. -func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) { - return ecdsa.Verify(pub, hash, r, s), nil +// EcdsaP384Sha384Verify verifies the signature in r, s of hash using ECDSA +// P384 + SHA 384 and the public key, pub. Its return value records whether +// the signature is valid. +func EcdsaP384Sha384Verify(pub *ecdsa.PublicKey, data []byte, r, s *big.Int) (bool, error) { + if pub.Curve != elliptic.P384() { + return false, fmt.Errorf("unsupported key curve: want P-384, got %v", pub.Curve) + } + digest := sha512.Sum384(data) + return ecdsa.Verify(pub, digest[:], r, s), nil } // SumSha384 returns the SHA384 checksum of the data. diff --git a/pkg/errors/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD index 201727780..e73b0e28a 100644 --- a/pkg/errors/linuxerr/BUILD +++ b/pkg/errors/linuxerr/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "linuxerr", - srcs = ["linuxerr.go"], + srcs = [ + "internal.go", + "linuxerr.go", + ], visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux/errno", @@ -20,7 +23,6 @@ go_test( ":linuxerr", "//pkg/abi/linux/errno", "//pkg/errors", - "//pkg/syserror", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/errors/linuxerr/internal.go b/pkg/errors/linuxerr/internal.go new file mode 100644 index 000000000..127bba0df --- /dev/null +++ b/pkg/errors/linuxerr/internal.go @@ -0,0 +1,120 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"),; +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linuxerr + +import ( + "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors" +) + +var ( + // ErrWouldBlock is an internal error used to indicate that an operation + // cannot be satisfied immediately, and should be retried at a later + // time, possibly when the caller has received a notification that the + // operation may be able to complete. It is used by implementations of + // the kio.File interface. + ErrWouldBlock = errors.New(errno.EWOULDBLOCK, "request would block") + + // ErrInterrupted is returned if a request is interrupted before it can + // complete. + ErrInterrupted = errors.New(errno.EINTR, "request was interrupted") + + // ErrExceedsFileSizeLimit is returned if a request would exceed the + // file's size limit. + ErrExceedsFileSizeLimit = errors.New(errno.E2BIG, "exceeds file size limit") +) + +var errorMap = map[error]*errors.Error{ + ErrWouldBlock: EWOULDBLOCK, + ErrInterrupted: EINTR, + ErrExceedsFileSizeLimit: EFBIG, +} + +// errorUnwrappers is an array of unwrap functions to extract typed errors. +var errorUnwrappers = []func(error) (*errors.Error, bool){} + +// AddErrorUnwrapper registers an unwrap method that can extract a concrete error +// from a typed, but not initialized, error. +func AddErrorUnwrapper(unwrap func(e error) (*errors.Error, bool)) { + errorUnwrappers = append(errorUnwrappers, unwrap) +} + +// TranslateError translates errors to errnos, it will return false if +// the error was not registered. +func TranslateError(from error) (*errors.Error, bool) { + if err, ok := errorMap[from]; ok { + return err, true + } + // Try to unwrap the error if we couldn't match an error + // exactly. This might mean that a package has its own + // error type. + for _, unwrap := range errorUnwrappers { + if err, ok := unwrap(from); ok { + return err, true + } + } + return nil, false +} + +// These errors are significant because ptrace syscall exit tracing can +// observe them. +// +// For all of the following errors, if the syscall is not interrupted by a +// signal delivered to a user handler, the syscall is restarted. +var ( + // ERESTARTSYS is returned by an interrupted syscall to indicate that it + // should be converted to EINTR if interrupted by a signal delivered to a + // user handler without SA_RESTART set, and restarted otherwise. + ERESTARTSYS = errors.New(errno.ERESTARTSYS, "to be restarted if SA_RESTART is set") + + // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it + // should always be restarted. + ERESTARTNOINTR = errors.New(errno.ERESTARTNOINTR, "to be restarted") + + // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it + // should be converted to EINTR if interrupted by a signal delivered to a + // user handler, and restarted otherwise. + ERESTARTNOHAND = errors.New(errno.ERESTARTNOHAND, "to be restarted if no handler") + + // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate + // that it should be restarted using a custom function. The interrupted + // syscall must register a custom restart function by calling + // Task.SetRestartSyscallFn. + ERESTART_RESTARTBLOCK = errors.New(errno.ERESTART_RESTARTBLOCK, "interrupted by signal") +) + +var restartMap = map[int]*errors.Error{ + -int(errno.ERESTARTSYS): ERESTARTSYS, + -int(errno.ERESTARTNOINTR): ERESTARTNOINTR, + -int(errno.ERESTARTNOHAND): ERESTARTNOHAND, + -int(errno.ERESTART_RESTARTBLOCK): ERESTART_RESTARTBLOCK, +} + +// IsRestartError checks if a given error is a restart error. +func IsRestartError(err error) bool { + switch err { + case ERESTARTSYS, ERESTARTNOINTR, ERESTARTNOHAND, ERESTART_RESTARTBLOCK: + return true + default: + return false + } +} + +// SyscallRestartErrorFromReturn returns the SyscallRestartErrno represented by +// rv, the value in a syscall return register. +func SyscallRestartErrorFromReturn(rv uintptr) (*errors.Error, bool) { + err, ok := restartMap[int(rv)] + return err, ok +} diff --git a/pkg/errors/linuxerr/linuxerr.go b/pkg/errors/linuxerr/linuxerr.go index f9f8412e0..5905ef593 100644 --- a/pkg/errors/linuxerr/linuxerr.go +++ b/pkg/errors/linuxerr/linuxerr.go @@ -27,6 +27,12 @@ import ( const maxErrno uint32 = errno.EHWPOISON + 1 +// The following errors are semantically identical to Errno of type unix.Errno +// or sycall.Errno. However, since the type are distinct ( these are +// *errors.Error), they are not directly comperable. However, the Errno method +// returns an Errno number such that the error can be compared to unix/syscall.Errno +// (e.g. unix.Errno(EPERM.Errno()) == unix.EPERM is true). Converting unix/syscall.Errno +// to the errors should be done via the lookup methods provided. var ( NOERROR = errors.New(errno.NOERRNO, "not an error") EPERM = errors.New(errno.EPERM, "operation not permitted") @@ -177,7 +183,7 @@ var ( var errNotValidError = errors.New(errno.Errno(maxErrno), "not a valid error") // The following errorSlice holds errors by errno for fast translation between -// errnos (especially uint32(sycall.Errno)) and *Error. +// errnos (especially uint32(sycall.Errno)) and *errors.Error. var errorSlice = []*errors.Error{ // Errno values from include/uapi/asm-generic/errno-base.h. errno.NOERRNO: NOERROR, diff --git a/pkg/errors/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go index f09d61b02..df7cd1c5a 100644 --- a/pkg/errors/linuxerr/linuxerr_test.go +++ b/pkg/errors/linuxerr/linuxerr_test.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package syserror_test +package linuxerr_test import ( "errors" + "fmt" "io" "io/fs" "syscall" @@ -25,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux/errno" gErrors "gvisor.dev/gvisor/pkg/errors" "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/syserror" ) var globalError error @@ -42,12 +42,6 @@ func BenchmarkAssignLinuxerr(b *testing.B) { } } -func BenchmarkAssignSyserror(b *testing.B) { - for i := b.N; i > 0; i-- { - globalError = linuxerr.ENOMSG - } -} - func BenchmarkCompareUnix(b *testing.B) { globalError = unix.EAGAIN j := 0 @@ -68,16 +62,6 @@ func BenchmarkCompareLinuxerr(b *testing.B) { } } -func BenchmarkCompareSyserror(b *testing.B) { - globalError = linuxerr.EAGAIN - j := 0 - for i := b.N; i > 0; i-- { - if globalError == linuxerr.EACCES { - j++ - } - } -} - func BenchmarkSwitchUnix(b *testing.B) { globalError = unix.EPERM j := 0 @@ -108,21 +92,6 @@ func BenchmarkSwitchLinuxerr(b *testing.B) { } } -func BenchmarkSwitchSyserror(b *testing.B) { - globalError = linuxerr.EPERM - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case linuxerr.EACCES: - j++ - case syserror.EINTR: - j += 2 - case linuxerr.EAGAIN: - j += 3 - } - } -} - func BenchmarkReturnUnix(b *testing.B) { var localError error f := func() error { @@ -170,47 +139,40 @@ func BenchmarkConvertUnixLinuxerrZero(b *testing.B) { } type translationTestTable struct { - fn string errIn error - syscallErrorIn unix.Errno expectedBool bool - expectedTranslation unix.Errno + expectedTranslation *gErrors.Error } func TestErrorTranslation(t *testing.T) { - myError := errors.New("My test error") - myError2 := errors.New("Another test error") testTable := []translationTestTable{ - {"TranslateError", myError, 0, false, 0}, - {"TranslateError", myError2, 0, false, 0}, - {"AddErrorTranslation", myError, unix.EAGAIN, true, 0}, - {"AddErrorTranslation", myError, unix.EAGAIN, false, 0}, - {"AddErrorTranslation", myError, unix.EPERM, false, 0}, - {"TranslateError", myError, 0, true, unix.EAGAIN}, - {"TranslateError", myError2, 0, false, 0}, - {"AddErrorTranslation", myError2, unix.EPERM, true, 0}, - {"AddErrorTranslation", myError2, unix.EPERM, false, 0}, - {"AddErrorTranslation", myError2, unix.EAGAIN, false, 0}, - {"TranslateError", myError, 0, true, unix.EAGAIN}, - {"TranslateError", myError2, 0, true, unix.EPERM}, + { + errIn: linuxerr.ENOENT, + }, + { + errIn: unix.ENOENT, + }, + { + errIn: linuxerr.ErrInterrupted, + expectedBool: true, + expectedTranslation: linuxerr.EINTR, + }, + { + errIn: linuxerr.ERESTART_RESTARTBLOCK, + }, + { + errIn: errors.New("some new error"), + }, } for _, tt := range testTable { - switch tt.fn { - case "TranslateError": - err, ok := syserror.TranslateError(tt.errIn) - if ok != tt.expectedBool { - t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) + t.Run(fmt.Sprintf("err: %v %T", tt.errIn, tt.errIn), func(t *testing.T) { + err, ok := linuxerr.TranslateError(tt.errIn) + if (!tt.expectedBool && err != nil) || (tt.expectedBool != ok) { + t.Fatalf("%v => %v %v expected %v err: nil", tt.errIn, err, ok, tt.expectedBool) } else if err != tt.expectedTranslation { - t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation) - } - case "AddErrorTranslation": - ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn) - if ok != tt.expectedBool { - t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) + t.Fatalf("%v => %v expected %v", tt.errIn, err, tt.expectedTranslation) } - default: - t.Fatalf("Unknown function %v", tt.fn) - } + }) } } diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index ad15d3672..56399232a 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "event.go", "event_any.go", + "processor.go", "rate.go", ], visibility = ["//:sandbox"], diff --git a/pkg/eventchannel/event_any.go b/pkg/eventchannel/event_any.go index 13f300061..b708937a4 100644 --- a/pkg/eventchannel/event_any.go +++ b/pkg/eventchannel/event_any.go @@ -26,3 +26,8 @@ import ( func newAny(m proto.Message) (*anypb.Any, error) { return anypb.New(m) } + +func emptyAny() *anypb.Any { + var any anypb.Any + return &any +} diff --git a/pkg/eventchannel/processor.go b/pkg/eventchannel/processor.go new file mode 100644 index 000000000..e765c10d1 --- /dev/null +++ b/pkg/eventchannel/processor.go @@ -0,0 +1,130 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package eventchannel + +import ( + "encoding/binary" + "fmt" + "io" + "os" + "time" + + "google.golang.org/protobuf/proto" + pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto" +) + +// eventProcessor carries display state across multiple events. +type eventProcessor struct { + filtering bool + // filtered is the number of events omitted since printing the last matching + // event. Only meaningful when filtering == true. + filtered uint64 + // allowlist is the set of event names to display. If empty, all events are + // displayed. + allowlist map[string]bool +} + +// newEventProcessor creates a new EventProcessor with filters. +func newEventProcessor(filters []string) *eventProcessor { + e := &eventProcessor{ + filtering: len(filters) > 0, + allowlist: make(map[string]bool), + } + for _, f := range filters { + e.allowlist[f] = true + } + return e +} + +// processOne reads, parses and displays a single event from the event channel. +// +// The event channel is a stream of (msglen, payload) packets; this function +// processes a single such packet. The msglen is a uvarint-encoded length for +// the associated payload. The payload is a binary-encoded 'Any' protobuf, which +// in turn encodes an arbitrary event protobuf. +func (e *eventProcessor) processOne(src io.Reader, out *os.File) error { + // Read and parse the msglen. + lenbuf := make([]byte, binary.MaxVarintLen64) + if _, err := io.ReadFull(src, lenbuf); err != nil { + return err + } + msglen, consumed := binary.Uvarint(lenbuf) + if consumed <= 0 { + return fmt.Errorf("couldn't parse the message length") + } + + // Read the payload. + buf := make([]byte, msglen) + // Copy any unused bytes from the len buffer into the payload buffer. These + // bytes are actually part of the payload. + extraBytes := copy(buf, lenbuf[consumed:]) + if _, err := io.ReadFull(src, buf[extraBytes:]); err != nil { + return err + } + + // Unmarshal the payload into an "Any" protobuf, which encodes the actual + // event. + encodedEv := emptyAny() + if err := proto.Unmarshal(buf, encodedEv); err != nil { + return fmt.Errorf("failed to unmarshal 'any' protobuf message: %v", err) + } + + var ev pb.DebugEvent + if err := (encodedEv).UnmarshalTo(&ev); err != nil { + return fmt.Errorf("failed to decode 'any' protobuf message: %v", err) + } + + if e.filtering && e.allowlist[ev.Name] { + e.filtered++ + return nil + } + + if e.filtering && e.filtered > 0 { + if e.filtered == 1 { + fmt.Fprintf(out, "... filtered %d event ...\n\n", e.filtered) + } else { + fmt.Fprintf(out, "... filtered %d events ...\n\n", e.filtered) + } + e.filtered = 0 + } + + // Extract the inner event and display it. Example: + // + // 2017-10-04 14:35:05.316180374 -0700 PDT m=+1.132485846 + // cloud_gvisor.MemoryUsage { + // total: 23822336 + // } + fmt.Fprintf(out, "%v\n%v {\n", time.Now(), ev.Name) + fmt.Fprintf(out, "%v", ev.Text) + fmt.Fprintf(out, "}\n\n") + + return nil +} + +// ProcessAll reads, parses and displays all events from src. The events are +// displayed to out. +func ProcessAll(src io.Reader, filters []string, out *os.File) error { + ep := newEventProcessor(filters) + for { + switch err := ep.processOne(src, out); err { + case nil: + continue + case io.EOF: + return nil + default: + return err + } + } +} diff --git a/pkg/eventfd/BUILD b/pkg/eventfd/BUILD new file mode 100644 index 000000000..02407cb99 --- /dev/null +++ b/pkg/eventfd/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "eventfd", + srcs = [ + "eventfd.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/hostarch", + "//pkg/tcpip/link/rawfile", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "eventfd_test", + srcs = ["eventfd_test.go"], + library = ":eventfd", +) diff --git a/pkg/eventfd/eventfd.go b/pkg/eventfd/eventfd.go new file mode 100644 index 000000000..acdac01b8 --- /dev/null +++ b/pkg/eventfd/eventfd.go @@ -0,0 +1,115 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package eventfd wraps Linux's eventfd(2) syscall. +package eventfd + +import ( + "fmt" + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" +) + +const sizeofUint64 = 8 + +// Eventfd represents a Linux eventfd object. +type Eventfd struct { + fd int +} + +// Create returns an initialized eventfd. +func Create() (Eventfd, error) { + fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0) + if err != 0 { + return Eventfd{}, fmt.Errorf("failed to create eventfd: %v", error(err)) + } + if err := unix.SetNonblock(int(fd), true); err != nil { + unix.Close(int(fd)) + return Eventfd{}, err + } + return Eventfd{int(fd)}, nil +} + +// Wrap returns an initialized Eventfd using the provided fd. +func Wrap(fd int) Eventfd { + return Eventfd{fd} +} + +// Close closes the eventfd, after which it should not be used. +func (ev Eventfd) Close() error { + return unix.Close(ev.fd) +} + +// Dup copies the eventfd, calling dup(2) on the underlying file descriptor. +func (ev Eventfd) Dup() (Eventfd, error) { + other, err := unix.Dup(ev.fd) + if err != nil { + return Eventfd{}, fmt.Errorf("failed to dup: %v", other) + } + return Eventfd{other}, nil +} + +// Notify alerts other users of the eventfd. Users can receive alerts by +// calling Wait or Read. +func (ev Eventfd) Notify() error { + return ev.Write(1) +} + +// Write writes a specific value to the eventfd. +func (ev Eventfd) Write(val uint64) error { + var buf [sizeofUint64]byte + hostarch.ByteOrder.PutUint64(buf[:], val) + for { + n, err := unix.Write(ev.fd, buf[:]) + if err == unix.EINTR { + continue + } + if n != sizeofUint64 { + panic(fmt.Sprintf("short write to eventfd: got %d bytes, wanted %d", n, sizeofUint64)) + } + return err + } +} + +// Wait blocks until eventfd is non-zero (i.e. someone calls Notify or Write). +func (ev Eventfd) Wait() error { + _, err := ev.Read() + return err +} + +// Read blocks until eventfd is non-zero (i.e. someone calls Notify or Write) +// and returns the value read. +func (ev Eventfd) Read() (uint64, error) { + var tmp [sizeofUint64]byte + n, err := rawfile.BlockingReadUntranslated(ev.fd, tmp[:]) + if err != 0 { + return 0, err + } + if n == 0 { + return 0, io.EOF + } + if n != sizeofUint64 { + panic(fmt.Sprintf("short read from eventfd: got %d bytes, wanted %d", n, sizeofUint64)) + } + return hostarch.ByteOrder.Uint64(tmp[:]), nil +} + +// FD returns the underlying file descriptor. Use with care, as this breaks the +// Eventfd abstraction. +func (ev Eventfd) FD() int { + return ev.fd +} diff --git a/pkg/eventfd/eventfd_test.go b/pkg/eventfd/eventfd_test.go new file mode 100644 index 000000000..96998d530 --- /dev/null +++ b/pkg/eventfd/eventfd_test.go @@ -0,0 +1,75 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package eventfd + +import ( + "testing" + "time" +) + +func TestReadWrite(t *testing.T) { + efd, err := Create() + if err != nil { + t.Fatalf("failed to Create(): %v", err) + } + defer efd.Close() + + // Make sure we can read actual values + const want = 343 + if err := efd.Write(want); err != nil { + t.Fatalf("failed to write value: %d", want) + } + + got, err := efd.Read() + if err != nil { + t.Fatalf("failed to read value: %v", err) + } + if got != want { + t.Fatalf("Read(): got %d, but wanted %d", got, want) + } +} + +func TestWait(t *testing.T) { + efd, err := Create() + if err != nil { + t.Fatalf("failed to Create(): %v", err) + } + defer efd.Close() + + // There's no way to test with certainty that Wait() blocks indefinitely, but + // as a best-effort we can wait a bit on it. + errCh := make(chan error) + go func() { + errCh <- efd.Wait() + }() + select { + case err := <-errCh: + t.Fatalf("Wait() returned without a call to Notify(): %v", err) + case <-time.After(500 * time.Millisecond): + } + + // Notify and check that Wait() returned. + if err := efd.Notify(); err != nil { + t.Fatalf("Notify() failed: %v", err) + } + select { + case err := <-errCh: + if err != nil { + t.Fatalf("Read() failed: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatalf("Read() did not return after Notify()") + } +} diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index 5d2ee4018..99410628f 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -121,7 +121,7 @@ func (ep *Endpoint) ctrlWaitFirst() error { return ep.futexWaitUntilActive() } -func (ep *Endpoint) ctrlRoundTrip() error { +func (ep *Endpoint) ctrlRoundTrip(mayRetainP bool) error { if err := ep.enterFutexWait(); err != nil { return err } @@ -133,6 +133,9 @@ func (ep *Endpoint) ctrlRoundTrip() error { if err := ep.futexWakePeer(); err != nil { return err } + // Since we don't know if the peer Endpoint is in the same process as this + // one (in which case it may need our P to run), we allow our P to be + // retaken regardless of mayRetainP. return ep.futexWaitUntilActive() } diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index f0e4ff487..88588ba0e 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -223,6 +223,23 @@ func (ep *Endpoint) RecvFirst() (uint32, error) { // * If ep is a client Endpoint, ep.Connect() has previously been called and // returned nil. func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { + return ep.sendRecv(dataLen, false /* mayRetainP */) +} + +// SendRecvFast is equivalent to SendRecv, but may prevent the caller's runtime +// P from being released, in which case the calling goroutine continues to +// count against GOMAXPROCS while waiting for the peer Endpoint to return +// control to the caller. +// +// SendRecvFast is appropriate if the peer Endpoint is expected to consistently +// return control in a short amount of time (less than ~10ms). +// +// Preconditions: As for SendRecv. +func (ep *Endpoint) SendRecvFast(dataLen uint32) (uint32, error) { + return ep.sendRecv(dataLen, true /* mayRetainP */) +} + +func (ep *Endpoint) sendRecv(dataLen uint32, mayRetainP bool) (uint32, error) { if dataLen > ep.dataCap { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) } @@ -233,7 +250,7 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { // they can only shoot themselves in the foot. *ep.dataLen() = dataLen raceBecomeInactive() - if err := ep.ctrlRoundTrip(); err != nil { + if err := ep.ctrlRoundTrip(mayRetainP); err != nil { return 0, err } raceBecomeActive() diff --git a/pkg/lisafs/BUILD b/pkg/lisafs/BUILD new file mode 100644 index 000000000..313c1756d --- /dev/null +++ b/pkg/lisafs/BUILD @@ -0,0 +1,117 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_template_instance( + name = "control_fd_refs", + out = "control_fd_refs.go", + package = "lisafs", + prefix = "controlFD", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "ControlFD", + }, +) + +go_template_instance( + name = "open_fd_refs", + out = "open_fd_refs.go", + package = "lisafs", + prefix = "openFD", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "OpenFD", + }, +) + +go_template_instance( + name = "control_fd_list", + out = "control_fd_list.go", + package = "lisafs", + prefix = "controlFD", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*ControlFD", + "Linker": "*ControlFD", + }, +) + +go_template_instance( + name = "open_fd_list", + out = "open_fd_list.go", + package = "lisafs", + prefix = "openFD", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*OpenFD", + "Linker": "*OpenFD", + }, +) + +go_library( + name = "lisafs", + srcs = [ + "channel.go", + "client.go", + "client_file.go", + "communicator.go", + "connection.go", + "control_fd_list.go", + "control_fd_refs.go", + "fd.go", + "handlers.go", + "lisafs.go", + "message.go", + "open_fd_list.go", + "open_fd_refs.go", + "sample_message.go", + "server.go", + "sock.go", + ], + marshal = True, + deps = [ + "//pkg/abi/linux", + "//pkg/cleanup", + "//pkg/context", + "//pkg/fdchannel", + "//pkg/flipcall", + "//pkg/fspath", + "//pkg/hostarch", + "//pkg/log", + "//pkg/marshal/primitive", + "//pkg/p9", + "//pkg/refsvfs2", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "sock_test", + size = "small", + srcs = ["sock_test.go"], + library = ":lisafs", + deps = [ + "//pkg/marshal", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "connection_test", + size = "small", + srcs = ["connection_test.go"], + deps = [ + ":lisafs", + "//pkg/sync", + "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md index 51d0d40e5..6b857321a 100644 --- a/pkg/lisafs/README.md +++ b/pkg/lisafs/README.md @@ -1,5 +1,8 @@ # Replacing 9P +NOTE: LISAFS is **NOT** production ready. There are still some security concerns +that must be resolved first. + ## Background The Linux filesystem model consists of the following key aspects (modulo mounts, diff --git a/pkg/lisafs/channel.go b/pkg/lisafs/channel.go new file mode 100644 index 000000000..301212e51 --- /dev/null +++ b/pkg/lisafs/channel.go @@ -0,0 +1,190 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math" + "runtime" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +var ( + chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes()) +) + +// maxChannels returns the number of channels a client can create. +// +// The server will reject channel creation requests beyond this (per client). +// Note that we don't want the number of channels to be too large, because each +// accounts for a large region of shared memory. +// TODO(gvisor.dev/issue/6313): Tune the number of channels. +func maxChannels() int { + maxChans := runtime.GOMAXPROCS(0) + if maxChans < 2 { + maxChans = 2 + } + if maxChans > 4 { + maxChans = 4 + } + return maxChans +} + +// channel implements Communicator and represents the communication endpoint +// for the client and server and is used to perform fast IPC. Apart from +// communicating data, a channel is also capable of donating file descriptors. +type channel struct { + fdTracker + dead bool + data flipcall.Endpoint + fdChan fdchannel.Endpoint +} + +var _ Communicator = (*channel)(nil) + +// PayloadBuf implements Communicator.PayloadBuf. +func (ch *channel) PayloadBuf(size uint32) []byte { + return ch.data.Data()[chanHeaderLen : chanHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + // Write header. Requests can not donate FDs. + ch.marshalHdr(m, 0 /* numFDs */) + + // One-shot communication. RPCs are expected to be quick rather than block. + rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen) + if err != nil { + // This channel is now unusable. + ch.dead = true + // Map the transport errors to EIO, but also log the real error. + log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err) + return 0, 0, unix.EIO + } + + return ch.rcvMsg(rcvDataLen) +} + +func (ch *channel) shutdown() { + ch.data.Shutdown() +} + +func (ch *channel) destroy() { + ch.dead = true + ch.fdChan.Destroy() + ch.data.Destroy() +} + +// createChannel creates a server side channel. It returns a packet window +// descriptor (for the data channel) and an open socket for the FD channel. +func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + // If c.channels is nil, the connection has closed. + if c.channels == nil || len(c.channels) >= maxChannels() { + return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS + } + ch := &channel{} + + // Set up data channel. + desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize)) + if err != nil { + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + if err := ch.data.Init(flipcall.ServerSide, desc); err != nil { + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + + // Set up FD channel. + fdSocks, err := fdchannel.NewConnectedSockets() + if err != nil { + ch.data.Destroy() + return nil, flipcall.PacketWindowDescriptor{}, -1, err + } + ch.fdChan.Init(fdSocks[0]) + clientFDSock := fdSocks[1] + + c.channels = append(c.channels, ch) + return ch, desc, clientFDSock, nil +} + +// sendFDs sends as many FDs as it can. The failure to send an FD does not +// cause an error and fail the entire RPC. FDs are considered supplementary +// responses that are not critical to the RPC response itself. The failure to +// send the (i)th FD will cause all the following FDs to not be sent as well +// because the order in which FDs are donated is important. +func (ch *channel) sendFDs(fds []int) uint8 { + numFDs := len(fds) + if numFDs == 0 { + return 0 + } + + if numFDs > math.MaxUint8 { + log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs) + return 0 + } + + for i, fd := range fds { + if err := ch.fdChan.SendFD(fd); err != nil { + log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err) + return uint8(i) + } + } + return uint8(numFDs) +} + +// channelHeader is the header present in front of each message received on +// flipcall endpoint when the protocol version being used is 1. +// +// +marshal +type channelHeader struct { + message MID + numFDs uint8 + _ uint8 // Need to make struct packed. +} + +func (ch *channel) marshalHdr(m MID, numFDs uint8) { + header := &channelHeader{ + message: m, + numFDs: numFDs, + } + header.MarshalUnsafe(ch.data.Data()) +} + +func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) { + if dataLen < chanHeaderLen { + log.Warningf("received data has size smaller than header length: %d", dataLen) + return 0, 0, unix.EIO + } + + // Read header first. + var header channelHeader + header.UnmarshalUnsafe(ch.data.Data()) + + // Read any FDs. + for i := 0; i < int(header.numFDs); i++ { + fd, err := ch.fdChan.RecvFDNonblock() + if err != nil { + log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err) + break + } + ch.TrackFD(fd) + } + + return header.message, dataLen - chanHeaderLen, nil +} diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go new file mode 100644 index 000000000..ccf1b9f72 --- /dev/null +++ b/pkg/lisafs/client.go @@ -0,0 +1,432 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + "math" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +const ( + // fdsToCloseBatchSize is the number of closed FDs batched before an Close + // RPC is made to close them all. fdsToCloseBatchSize is immutable. + fdsToCloseBatchSize = 100 +) + +// Client helps manage a connection to the lisafs server and pass messages +// efficiently. There is a 1:1 mapping between a Connection and a Client. +type Client struct { + // sockComm is the main socket by which this connections is established. + // Communication over the socket is synchronized by sockMu. + sockMu sync.Mutex + sockComm *sockCommunicator + + // channelsMu protects channels and availableChannels. + channelsMu sync.Mutex + // channels tracks all the channels. + channels []*channel + // availableChannels is a LIFO (stack) of channels available to be used. + availableChannels []*channel + // activeWg represents active channels. + activeWg sync.WaitGroup + + // watchdogWg only holds the watchdog goroutine. + watchdogWg sync.WaitGroup + + // supported caches information about which messages are supported. It is + // indexed by MID. An MID is supported if supported[MID] is true. + supported []bool + + // maxMessageSize is the maximum payload length (in bytes) that can be sent. + // It is initialized on Mount and is immutable. + maxMessageSize uint32 + + // fdsToClose tracks the FDs to close. It caches the FDs no longer being used + // by the client and closes them in one shot. It is not preserved across + // checkpoint/restore as FDIDs are not preserved. + fdsMu sync.Mutex + fdsToClose []FDID +} + +// NewClient creates a new client for communication with the server. It mounts +// the server and creates channels for fast IPC. NewClient takes ownership over +// the passed socket. On success, it returns the initialized client along with +// the root Inode. +func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) { + maxChans := maxChannels() + c := &Client{ + sockComm: newSockComm(sock), + channels: make([]*channel, 0, maxChans), + availableChannels: make([]*channel, 0, maxChans), + maxMessageSize: 1 << 20, // 1 MB for now. + fdsToClose: make([]FDID, 0, fdsToCloseBatchSize), + } + + // Start a goroutine to check socket health. This goroutine is also + // responsible for client cleanup. + c.watchdogWg.Add(1) + go c.watchdog() + + // Clean everything up if anything fails. + cu := cleanup.Make(func() { + c.Close() + }) + defer cu.Clean() + + // Mount the server first. Assume Mount is supported so that we can make the + // Mount RPC below. + c.supported = make([]bool, Mount+1) + c.supported[Mount] = true + mountMsg := MountReq{ + MountPath: SizedString(mountPath), + } + var mountResp MountResp + if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil { + return nil, nil, err + } + + // Initialize client. + c.maxMessageSize = uint32(mountResp.MaxMessageSize) + var maxSuppMID MID + for _, suppMID := range mountResp.SupportedMs { + if suppMID > maxSuppMID { + maxSuppMID = suppMID + } + } + c.supported = make([]bool, maxSuppMID+1) + for _, suppMID := range mountResp.SupportedMs { + c.supported[suppMID] = true + } + + // Create channels parallely so that channels can be used to create more + // channels and costly initialization like flipcall.Endpoint.Connect can + // proceed parallely. + var channelsWg sync.WaitGroup + channelErrs := make([]error, maxChans) + for i := 0; i < maxChans; i++ { + channelsWg.Add(1) + curChanID := i + go func() { + defer channelsWg.Done() + ch, err := c.createChannel() + if err != nil { + log.Warningf("channel creation failed: %v", err) + channelErrs[curChanID] = err + return + } + c.channelsMu.Lock() + c.channels = append(c.channels, ch) + c.availableChannels = append(c.availableChannels, ch) + c.channelsMu.Unlock() + }() + } + channelsWg.Wait() + + for _, channelErr := range channelErrs { + // Return the first non-nil channel creation error. + if channelErr != nil { + return nil, nil, channelErr + } + } + cu.Release() + + return c, &mountResp.Root, nil +} + +func (c *Client) watchdog() { + defer c.watchdogWg.Done() + + events := []unix.PollFd{ + { + Fd: int32(c.sockComm.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == unix.EINTR || err == unix.EAGAIN { + continue + } + if err != nil { + log.Warningf("lisafs.Client.watch(): %v", err) + } else if n != 1 { + log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Shutdown all active channels and wait for them to complete. + c.shutdownActiveChans() + c.activeWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.destroy() + } + c.channelsMu.Unlock() + + // Close main socket. + c.sockComm.destroy() +} + +func (c *Client) shutdownActiveChans() { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + + availableChans := make(map[*channel]bool) + for _, ch := range c.availableChannels { + availableChans[ch] = true + } + for _, ch := range c.channels { + // A channel that is not available is active. + if _, ok := availableChans[ch]; !ok { + log.Debugf("shutting down active channel@%p...", ch) + ch.shutdown() + } + } + + // Prevent channels from becoming available and serving new requests. + c.availableChannels = nil +} + +// Close shuts down the main socket and waits for the watchdog to clean up. +func (c *Client) Close() { + // This shutdown has no effect if the watchdog has already fired and closed + // the main socket. + c.sockComm.shutdown() + c.watchdogWg.Wait() +} + +func (c *Client) createChannel() (*channel, error) { + var chanResp ChannelResp + var fds [2]int + if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil { + return nil, err + } + if fds[0] < 0 || fds[1] < 0 { + closeFDs(fds[:]) + return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds) + } + + // Lets create the channel. + defer closeFDs(fds[:1]) // The data FD is not needed after this. + desc := flipcall.PacketWindowDescriptor{ + FD: fds[0], + Offset: chanResp.dataOffset, + Length: int(chanResp.dataLength), + } + + ch := &channel{} + if err := ch.data.Init(flipcall.ClientSide, desc); err != nil { + closeFDs(fds[1:]) + return nil, err + } + ch.fdChan.Init(fds[1]) // fdChan now owns this FD. + + // Only a connected channel is usable. + if err := ch.data.Connect(); err != nil { + ch.destroy() + return nil, err + } + return ch, nil +} + +// IsSupported returns true if this connection supports the passed message. +func (c *Client) IsSupported(m MID) bool { + return int(m) < len(c.supported) && c.supported[m] +} + +// CloseFDBatched either queues the passed FD to be closed or makes a batch +// RPC to close all the accumulated FDs-to-close. +func (c *Client) CloseFDBatched(ctx context.Context, fd FDID) { + c.fdsMu.Lock() + c.fdsToClose = append(c.fdsToClose, fd) + if len(c.fdsToClose) < fdsToCloseBatchSize { + c.fdsMu.Unlock() + return + } + + // Flush the cache. We should not hold fdsMu while making an RPC, so be sure + // to copy the fdsToClose to another buffer before unlocking fdsMu. + var toCloseArr [fdsToCloseBatchSize]FDID + toClose := toCloseArr[:len(c.fdsToClose)] + copy(toClose, c.fdsToClose) + + // Clear fdsToClose so other FDIDs can be appended. + c.fdsToClose = c.fdsToClose[:0] + c.fdsMu.Unlock() + + req := CloseReq{FDs: toClose} + ctx.UninterruptibleSleepStart(false) + err := c.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + log.Warningf("lisafs: batch closing FDs returned error: %v", err) + } +} + +// SyncFDs makes a Fsync RPC to sync multiple FDs. +func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error { + if len(fds) == 0 { + return nil + } + req := FsyncReq{FDs: fds} + ctx.UninterruptibleSleepStart(false) + err := c.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// SndRcvMessage invokes reqMarshal to marshal the request onto the payload +// buffer, wakes up the server to process the request, waits for the response +// and invokes respUnmarshal with the response payload. respFDs is populated +// with the received FDs, extra fields are set to -1. +// +// Note that the function arguments intentionally accept marshal.Marshallable +// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of +// directly accepting the marshal.Marshallable interface. Even though just +// accepting marshal.Marshallable is cleaner, it leads to a heap allocation +// (even if that interface variable itself does not escape). In other words, +// implicit conversion to an interface leads to an allocation. +// +// Precondition: reqMarshal and respUnmarshal must be non-nil. +func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error { + if !c.IsSupported(m) { + return unix.EOPNOTSUPP + } + if payloadLen > c.maxMessageSize { + log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize) + return unix.EIO + } + wantFDs := len(respFDs) + if wantFDs > math.MaxUint8 { + log.Warningf("want too many FDs: %d", wantFDs) + return unix.EINVAL + } + + // Acquire a communicator. + comm := c.acquireCommunicator() + defer c.releaseCommunicator(comm) + + // Marshal the request into comm's payload buffer and make the RPC. + reqMarshal(comm.PayloadBuf(payloadLen)) + respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs)) + + // Handle FD donation. + rcvFDs := comm.ReleaseFDs() + if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 { + // releasedFDs is memory owned by comm which can not be returned to caller. + // Copy it into the caller's buffer. + numFDCopied := copy(respFDs, rcvFDs) + if numFDCopied < numRcvFDs { + log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs) + closeFDs(rcvFDs[numFDCopied:]) + } + if numFDCopied < wantFDs { + for i := numFDCopied; i < wantFDs; i++ { + respFDs[i] = -1 + } + } + } + + // Error cases. + if err != nil { + closeFDs(respFDs) + return err + } + if respM == Error { + closeFDs(respFDs) + var resp ErrorResp + resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen)) + return unix.Errno(resp.errno) + } + if respM != m { + closeFDs(respFDs) + log.Warningf("sent %d message but got %d in response", m, respM) + return unix.EINVAL + } + + // Success. The payload must be unmarshalled *before* comm is released. + respUnmarshal(comm.PayloadBuf(respPayloadLen)) + return nil +} + +// Postcondition: releaseCommunicator() must be called on the returned value. +func (c *Client) acquireCommunicator() Communicator { + // Prefer using channel over socket because: + // - Channel uses a shared memory region for passing messages. IO from shared + // memory is faster and does not involve making a syscall. + // - No intermediate buffer allocation needed. With a channel, the message + // can be directly pasted into the shared memory region. + if ch := c.getChannel(); ch != nil { + return ch + } + + c.sockMu.Lock() + return c.sockComm +} + +// Precondition: comm must have been acquired via acquireCommunicator(). +func (c *Client) releaseCommunicator(comm Communicator) { + switch t := comm.(type) { + case *sockCommunicator: + c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator(). + case *channel: + c.releaseChannel(t) + default: + panic(fmt.Sprintf("unknown communicator type %T", t)) + } +} + +// getChannel pops a channel from the available channels stack. The caller must +// release the channel after use. +func (c *Client) getChannel() *channel { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + if len(c.availableChannels) == 0 { + return nil + } + + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + c.activeWg.Add(1) + return ch +} + +// releaseChannel pushes the passed channel onto the available channel stack if +// reinsert is true. +func (c *Client) releaseChannel(ch *channel) { + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + + // If availableChannels is nil, then watchdog has fired and the client is + // shutting down. So don't make this channel available again. + if !ch.dead && c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.activeWg.Done() +} diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go new file mode 100644 index 000000000..170c15705 --- /dev/null +++ b/pkg/lisafs/client_file.go @@ -0,0 +1,528 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// ClientFD is a wrapper around FDID that provides client-side utilities +// so that RPC making is easier. +type ClientFD struct { + fd FDID + client *Client +} + +// ID returns the underlying FDID. +func (f *ClientFD) ID() FDID { + return f.fd +} + +// Client returns the backing Client. +func (f *ClientFD) Client() *Client { + return f.client +} + +// NewFD initializes a new ClientFD. +func (c *Client) NewFD(fd FDID) ClientFD { + return ClientFD{ + client: c, + fd: fd, + } +} + +// Ok returns true if the underlying FD is ok. +func (f *ClientFD) Ok() bool { + return f.fd.Ok() +} + +// CloseBatched queues this FD to be closed on the server and resets f.fd. +// This maybe invoke the Close RPC if the queue is full. +func (f *ClientFD) CloseBatched(ctx context.Context) { + f.client.CloseFDBatched(ctx, f.fd) + f.fd = InvalidFDID +} + +// Close closes this FD immediately (invoking a Close RPC). Consider using +// CloseBatched if closing this FD on remote right away is not critical. +func (f *ClientFD) Close(ctx context.Context) error { + fdArr := [1]FDID{f.fd} + req := CloseReq{FDs: fdArr[:]} + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// OpenAt makes the OpenAt RPC. +func (f *ClientFD) OpenAt(ctx context.Context, flags uint32) (FDID, int, error) { + req := OpenAtReq{ + FD: f.fd, + Flags: flags, + } + var respFD [1]int + var resp OpenAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(OpenAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, respFD[:]) + ctx.UninterruptibleSleepFinish(false) + return resp.NewFD, respFD[0], err +} + +// OpenCreateAt makes the OpenCreateAt RPC. +func (f *ClientFD) OpenCreateAt(ctx context.Context, name string, flags uint32, mode linux.FileMode, uid UID, gid GID) (Inode, FDID, int, error) { + var req OpenCreateAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Flags = primitive.Uint32(flags) + req.Mode = mode + req.UID = uid + req.GID = gid + + var respFD [1]int + var resp OpenCreateAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(OpenCreateAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, respFD[:]) + ctx.UninterruptibleSleepFinish(false) + return resp.Child, resp.NewFD, respFD[0], err +} + +// StatTo makes the Fstat RPC and populates stat with the result. +func (f *ClientFD) StatTo(ctx context.Context, stat *linux.Statx) error { + req := StatReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FStat, uint32(req.SizeBytes()), req.MarshalUnsafe, stat.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Sync makes the Fsync RPC. +func (f *ClientFD) Sync(ctx context.Context) error { + req := FsyncReq{FDs: []FDID{f.fd}} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// chunkify applies fn to buf in chunks based on chunkSize. +func chunkify(chunkSize uint64, buf []byte, fn func([]byte, uint64) (uint64, error)) (uint64, error) { + toProcess := uint64(len(buf)) + var ( + totalProcessed uint64 + curProcessed uint64 + off uint64 + err error + ) + for { + if totalProcessed == toProcess { + return totalProcessed, nil + } + + if totalProcessed+chunkSize > toProcess { + curProcessed, err = fn(buf[totalProcessed:], off) + } else { + curProcessed, err = fn(buf[totalProcessed:totalProcessed+chunkSize], off) + } + totalProcessed += curProcessed + off += curProcessed + + if err != nil { + return totalProcessed, err + } + + // Return partial result immediately. + if curProcessed < chunkSize { + return totalProcessed, nil + } + + // If we received more bytes than we ever requested, this is a problem. + if totalProcessed > toProcess { + panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", totalProcessed, toProcess)) + } + } +} + +// Read makes the PRead RPC. +func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) { + var resp PReadResp + // maxDataReadSize represents the maximum amount of data we can read at once + // (maximum message size - metadata size present in resp). Uninitialized + // resp.SizeBytes() correctly returns the metadata size only (since the read + // buffer is empty). + maxDataReadSize := uint64(f.client.maxMessageSize) - uint64(resp.SizeBytes()) + return chunkify(maxDataReadSize, dst, func(buf []byte, curOff uint64) (uint64, error) { + req := PReadReq{ + Offset: offset + curOff, + FD: f.fd, + Count: uint32(len(buf)), + } + + // This will be unmarshalled into. Already set Buf so that we don't need to + // allocate a temporary buffer during unmarshalling. + // PReadResp.UnmarshalBytes expects this to be set. + resp.Buf = buf + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return uint64(resp.NumBytes), err + }) +} + +// Write makes the PWrite RPC. +func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) { + var req PWriteReq + // maxDataWriteSize represents the maximum amount of data we can write at + // once (maximum message size - metadata size present in req). Uninitialized + // req.SizeBytes() correctly returns the metadata size only (since the write + // buffer is empty). + maxDataWriteSize := uint64(f.client.maxMessageSize) - uint64(req.SizeBytes()) + return chunkify(maxDataWriteSize, src, func(buf []byte, curOff uint64) (uint64, error) { + req = PWriteReq{ + Offset: primitive.Uint64(offset + curOff), + FD: f.fd, + NumBytes: primitive.Uint32(len(buf)), + Buf: buf, + } + + var resp PWriteResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Count, err + }) +} + +// MkdirAt makes the MkdirAt RPC. +func (f *ClientFD) MkdirAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID) (*Inode, error) { + var req MkdirAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Mode = mode + req.UID = uid + req.GID = gid + + var resp MkdirAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(MkdirAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.ChildDir, err +} + +// SymlinkAt makes the SymlinkAt RPC. +func (f *ClientFD) SymlinkAt(ctx context.Context, name, target string, uid UID, gid GID) (*Inode, error) { + req := SymlinkAtReq{ + DirFD: f.fd, + Name: SizedString(name), + Target: SizedString(target), + UID: uid, + GID: gid, + } + + var resp SymlinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(SymlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Symlink, err +} + +// LinkAt makes the LinkAt RPC. +func (f *ClientFD) LinkAt(ctx context.Context, targetFD FDID, name string) (*Inode, error) { + req := LinkAtReq{ + DirFD: f.fd, + Target: targetFD, + Name: SizedString(name), + } + + var resp LinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(LinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Link, err +} + +// MknodAt makes the MknodAt RPC. +func (f *ClientFD) MknodAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID, minor, major uint32) (*Inode, error) { + var req MknodAtReq + req.DirFD = f.fd + req.Name = SizedString(name) + req.Mode = mode + req.UID = uid + req.GID = gid + req.Minor = primitive.Uint32(minor) + req.Major = primitive.Uint32(major) + + var resp MknodAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(MknodAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return &resp.Child, err +} + +// SetStat makes the SetStat RPC. +func (f *ClientFD) SetStat(ctx context.Context, stat *linux.Statx) (uint32, error, error) { + req := SetStatReq{ + FD: f.fd, + Mask: stat.Mask, + Mode: uint32(stat.Mode), + UID: UID(stat.UID), + GID: GID(stat.GID), + Size: stat.Size, + Atime: linux.Timespec{ + Sec: stat.Atime.Sec, + Nsec: int64(stat.Atime.Nsec), + }, + Mtime: linux.Timespec{ + Sec: stat.Mtime.Sec, + Nsec: int64(stat.Mtime.Nsec), + }, + } + + var resp SetStatResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(SetStat, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.FailureMask, unix.Errno(resp.FailureErrNo), err +} + +// WalkMultiple makes the Walk RPC with multiple path components. +func (f *ClientFD) WalkMultiple(ctx context.Context, names []string) (WalkStatus, []Inode, error) { + req := WalkReq{ + DirFD: f.fd, + Path: StringArray(names), + } + + var resp WalkResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Status, resp.Inodes, err +} + +// Walk makes the Walk RPC with just one path component to walk. +func (f *ClientFD) Walk(ctx context.Context, name string) (*Inode, error) { + req := WalkReq{ + DirFD: f.fd, + Path: []string{name}, + } + + var inode [1]Inode + resp := WalkResp{Inodes: inode[:]} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return nil, err + } + + switch resp.Status { + case WalkComponentDoesNotExist: + return nil, unix.ENOENT + case WalkComponentSymlink: + // f is not a directory which can be walked on. + return nil, unix.ENOTDIR + } + + if n := len(resp.Inodes); n > 1 { + for i := range resp.Inodes { + f.client.CloseFDBatched(ctx, resp.Inodes[i].ControlFD) + } + log.Warningf("requested to walk one component, but got %d results", n) + return nil, unix.EIO + } else if n == 0 { + log.Warningf("walk has success status but no results returned") + return nil, unix.ENOENT + } + return &inode[0], err +} + +// WalkStat makes the WalkStat RPC with multiple path components to walk. +func (f *ClientFD) WalkStat(ctx context.Context, names []string) ([]linux.Statx, error) { + req := WalkReq{ + DirFD: f.fd, + Path: StringArray(names), + } + + var resp WalkStatResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(WalkStat, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Stats, err +} + +// StatFSTo makes the FStatFS RPC and populates statFS with the result. +func (f *ClientFD) StatFSTo(ctx context.Context, statFS *StatFS) error { + req := FStatFSReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FStatFS, uint32(req.SizeBytes()), req.MarshalUnsafe, statFS.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Allocate makes the FAllocate RPC. +func (f *ClientFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + req := FAllocateReq{ + FD: f.fd, + Mode: mode, + Offset: offset, + Length: length, + } + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FAllocate, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// ReadLinkAt makes the ReadLinkAt RPC. +func (f *ClientFD) ReadLinkAt(ctx context.Context) (string, error) { + req := ReadLinkAtReq{FD: f.fd} + var resp ReadLinkAtResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(ReadLinkAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return string(resp.Target), err +} + +// Flush makes the Flush RPC. +func (f *ClientFD) Flush(ctx context.Context) error { + if !f.client.IsSupported(Flush) { + // If Flush is not supported, it probably means that it would be a noop. + return nil + } + req := FlushReq{FD: f.fd} + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Flush, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Connect makes the Connect RPC. +func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType) (int, error) { + req := ConnectReq{FD: f.fd, SockType: uint32(sockType)} + var sockFD [1]int + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, sockFD[:]) + ctx.UninterruptibleSleepFinish(false) + if err == nil && sockFD[0] < 0 { + err = unix.EBADF + } + return sockFD[0], err +} + +// UnlinkAt makes the UnlinkAt RPC. +func (f *ClientFD) UnlinkAt(ctx context.Context, name string, flags uint32) error { + req := UnlinkAtReq{ + DirFD: f.fd, + Name: SizedString(name), + Flags: primitive.Uint32(flags), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(UnlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// RenameTo makes the RenameAt RPC which renames f to newDirFD directory with +// name newName. +func (f *ClientFD) RenameTo(ctx context.Context, newDirFD FDID, newName string) error { + req := RenameAtReq{ + Renamed: f.fd, + NewDir: newDirFD, + NewName: SizedString(newName), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(RenameAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// Getdents64 makes the Getdents64 RPC. +func (f *ClientFD) Getdents64(ctx context.Context, count int32) ([]Dirent64, error) { + req := Getdents64Req{ + DirFD: f.fd, + Count: count, + } + + var resp Getdents64Resp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(Getdents64, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Dirents, err +} + +// ListXattr makes the FListXattr RPC. +func (f *ClientFD) ListXattr(ctx context.Context, size uint64) ([]string, error) { + req := FListXattrReq{ + FD: f.fd, + Size: size, + } + + var resp FListXattrResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FListXattr, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Xattrs, err +} + +// GetXattr makes the FGetXattr RPC. +func (f *ClientFD) GetXattr(ctx context.Context, name string, size uint64) (string, error) { + req := FGetXattrReq{ + FD: f.fd, + Name: SizedString(name), + BufSize: primitive.Uint32(size), + } + + var resp FGetXattrResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FGetXattr, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return string(resp.Value), err +} + +// SetXattr makes the FSetXattr RPC. +func (f *ClientFD) SetXattr(ctx context.Context, name string, value string, flags uint32) error { + req := FSetXattrReq{ + FD: f.fd, + Name: SizedString(name), + Value: SizedString(value), + Flags: primitive.Uint32(flags), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FSetXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} + +// RemoveXattr makes the FRemoveXattr RPC. +func (f *ClientFD) RemoveXattr(ctx context.Context, name string) error { + req := FRemoveXattrReq{ + FD: f.fd, + Name: SizedString(name), + } + + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(FRemoveXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil) + ctx.UninterruptibleSleepFinish(false) + return err +} diff --git a/pkg/lisafs/communicator.go b/pkg/lisafs/communicator.go new file mode 100644 index 000000000..ec2035158 --- /dev/null +++ b/pkg/lisafs/communicator.go @@ -0,0 +1,80 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import "golang.org/x/sys/unix" + +// Communicator is a server side utility which represents exactly how the +// server is communicating with the client. +type Communicator interface { + // PayloadBuf returns a slice to the payload section of its internal buffer + // where the message can be marshalled. The handlers should use this to + // populate the payload buffer with the message. + // + // The payload buffer contents *should* be preserved across calls with + // different sizes. Note that this is not a guarantee, because a compromised + // owner of a "shared" payload buffer can tamper with its contents anytime, + // even when it's not its turn to do so. + PayloadBuf(size uint32) []byte + + // SndRcvMessage sends message m. The caller must have populated PayloadBuf() + // with payloadLen bytes. The caller expects to receive wantFDs FDs. + // Any received FDs must be accessible via ReleaseFDs(). It returns the + // response message along with the response payload length. + SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) + + // DonateFD makes fd non-blocking and starts tracking it. The next call to + // ReleaseFDs will include fd in the order it was added. Communicator takes + // ownership of fd. Server side should call this. + DonateFD(fd int) error + + // Track starts tracking fd. The next call to ReleaseFDs will include fd in + // the order it was added. Communicator takes ownership of fd. Client side + // should use this for accumulating received FDs. + TrackFD(fd int) + + // ReleaseFDs returns the accumulated FDs and stops tracking them. The + // ownership of the FDs is transferred to the caller. + ReleaseFDs() []int +} + +// fdTracker is a partial implementation of Communicator. It can be embedded in +// Communicator implementations to keep track of FD donations. +type fdTracker struct { + fds []int +} + +// DonateFD implements Communicator.DonateFD. +func (d *fdTracker) DonateFD(fd int) error { + // Make sure the FD is non-blocking. + if err := unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return err + } + d.TrackFD(fd) + return nil +} + +// TrackFD implements Communicator.TrackFD. +func (d *fdTracker) TrackFD(fd int) { + d.fds = append(d.fds, fd) +} + +// ReleaseFDs implements Communicator.ReleaseFDs. +func (d *fdTracker) ReleaseFDs() []int { + ret := d.fds + d.fds = d.fds[:0] + return ret +} diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go new file mode 100644 index 000000000..f6e5ecb4f --- /dev/null +++ b/pkg/lisafs/connection.go @@ -0,0 +1,320 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// Connection represents a connection between a mount point in the client and a +// mount point in the server. It is owned by the server on which it was started +// and facilitates communication with the client mount. +// +// Each connection is set up using a unix domain socket. One end is owned by +// the server and the other end is owned by the client. The connection may +// spawn additional comunicational channels for the same mount for increased +// RPC concurrency. +type Connection struct { + // server is the server on which this connection was created. It is immutably + // associated with it for its entire lifetime. + server *Server + + // mounted is a one way flag indicating whether this connection has been + // mounted correctly and the server is initialized properly. + mounted bool + + // readonly indicates if this connection is readonly. All write operations + // will fail with EROFS. + readonly bool + + // sockComm is the main socket by which this connections is established. + sockComm *sockCommunicator + + // channelsMu protects channels. + channelsMu sync.Mutex + // channels keeps track of all open channels. + channels []*channel + + // activeWg represents active channels. + activeWg sync.WaitGroup + + // reqGate counts requests that are still being handled. + reqGate sync.Gate + + // channelAlloc is used to allocate memory for channels. + channelAlloc *flipcall.PacketWindowAllocator + + fdsMu sync.RWMutex + // fds keeps tracks of open FDs on this server. It is protected by fdsMu. + fds map[FDID]genericFD + // nextFDID is the next available FDID. It is protected by fdsMu. + nextFDID FDID +} + +// CreateConnection initializes a new connection - creating a server if +// required. The connection must be started separately. +func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) { + c := &Connection{ + sockComm: newSockComm(sock), + server: s, + readonly: readonly, + channels: make([]*channel, 0, maxChannels()), + fds: make(map[FDID]genericFD), + nextFDID: InvalidFDID + 1, + } + + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return nil, err + } + c.channelAlloc = alloc + return c, nil +} + +// Server returns the associated server. +func (c *Connection) Server() *Server { + return c.server +} + +// ServerImpl returns the associated server implementation. +func (c *Connection) ServerImpl() ServerImpl { + return c.server.impl +} + +// Run defines the lifecycle of a connection. +func (c *Connection) Run() { + defer c.close() + + // Start handling requests on this connection. + for { + m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */) + if err != nil { + log.Debugf("sock read failed, closing connection: %v", err) + return + } + + respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen) + err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs) + closeFDs(respFDs) + if err != nil { + log.Debugf("sock write failed, closing connection: %v", err) + return + } + } +} + +// service starts servicing the passed channel until the channel is shutdown. +// This is a blocking method and hence must be called in a separate goroutine. +func (c *Connection) service(ch *channel) error { + rcvDataLen, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rcvDataLen > 0 { + m, payloadLen, err := ch.rcvMsg(rcvDataLen) + if err != nil { + return err + } + respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen) + numFDs := ch.sendFDs(respFDs) + closeFDs(respFDs) + + ch.marshalHdr(respM, numFDs) + rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen) + if err != nil { + return err + } + } + return nil +} + +func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) { + resp := &ErrorResp{errno: uint32(err)} + respLen := uint32(resp.SizeBytes()) + resp.MarshalUnsafe(comm.PayloadBuf(respLen)) + return Error, respLen, nil +} + +func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) { + if !c.reqGate.Enter() { + // c.close() has been called; the connection is shutting down. + return c.respondError(comm, unix.ECONNRESET) + } + defer c.reqGate.Leave() + + if !c.mounted && m != Mount { + log.Warningf("connection must first be mounted") + return c.respondError(comm, unix.EINVAL) + } + + // Check if the message is supported for forward compatibility. + if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil { + log.Warningf("received request which is not supported by the server, MID = %d", m) + return c.respondError(comm, unix.EOPNOTSUPP) + } + + // Try handling the request. + respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen) + fds := comm.ReleaseFDs() + if err != nil { + closeFDs(fds) + return c.respondError(comm, p9.ExtractErrno(err)) + } + + return m, respPayloadLen, fds +} + +func (c *Connection) close() { + // Wait for completion of all inflight requests. This is mostly so that if + // a request is stuck, the sandbox supervisor has the opportunity to kill + // us with SIGABRT to get a stack dump of the offending handler. + c.reqGate.Close() + + // Shutdown and clean up channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.shutdown() + } + c.activeWg.Wait() + for _, ch := range c.channels { + ch.destroy() + } + // This is to prevent additional channels from being created. + c.channels = nil + c.channelsMu.Unlock() + + // Free the channel memory. + if c.channelAlloc != nil { + c.channelAlloc.Destroy() + } + + // Ensure the connection is closed. + c.sockComm.destroy() + + // Cleanup all FDs. + c.fdsMu.Lock() + for fdid := range c.fds { + fd := c.removeFDLocked(fdid) + fd.DecRef(nil) // Drop the ref held by c. + } + c.fdsMu.Unlock() +} + +// The caller gains a ref on the FD on success. +func (c *Connection) lookupFD(id FDID) (genericFD, error) { + c.fdsMu.RLock() + defer c.fdsMu.RUnlock() + + fd, ok := c.fds[id] + if !ok { + return nil, unix.EBADF + } + fd.IncRef() + return fd, nil +} + +// LookupControlFD retrieves the control FD identified by id on this +// connection. On success, the caller gains a ref on the FD. +func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) { + fd, err := c.lookupFD(id) + if err != nil { + return nil, err + } + + cfd, ok := fd.(*ControlFD) + if !ok { + fd.DecRef(nil) + return nil, unix.EINVAL + } + return cfd, nil +} + +// LookupOpenFD retrieves the open FD identified by id on this +// connection. On success, the caller gains a ref on the FD. +func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) { + fd, err := c.lookupFD(id) + if err != nil { + return nil, err + } + + ofd, ok := fd.(*OpenFD) + if !ok { + fd.DecRef(nil) + return nil, unix.EINVAL + } + return ofd, nil +} + +// insertFD inserts the passed fd into the internal datastructure to track FDs. +// The caller must hold a ref on fd which is transferred to the connection. +func (c *Connection) insertFD(fd genericFD) FDID { + c.fdsMu.Lock() + defer c.fdsMu.Unlock() + + res := c.nextFDID + c.nextFDID++ + if c.nextFDID < res { + panic("ran out of FDIDs") + } + c.fds[res] = fd + return res +} + +// RemoveFD makes c stop tracking the passed FDID and drops its ref on it. +func (c *Connection) RemoveFD(id FDID) { + c.fdsMu.Lock() + fd := c.removeFDLocked(id) + c.fdsMu.Unlock() + if fd != nil { + // Drop the ref held by c. This can take arbitrarily long. So do not hold + // c.fdsMu while calling it. + fd.DecRef(nil) + } +} + +// RemoveControlFDLocked is the same as RemoveFD with added preconditions. +// +// Preconditions: +// * server's rename mutex must at least be read locked. +// * id must be pointing to a control FD. +func (c *Connection) RemoveControlFDLocked(id FDID) { + c.fdsMu.Lock() + fd := c.removeFDLocked(id) + c.fdsMu.Unlock() + if fd != nil { + // Drop the ref held by c. This can take arbitrarily long. So do not hold + // c.fdsMu while calling it. + fd.(*ControlFD).DecRefLocked() + } +} + +// removeFDLocked makes c stop tracking the passed FDID. Note that the caller +// must drop ref on the returned fd (preferably without holding c.fdsMu). +// +// Precondition: c.fdsMu is locked. +func (c *Connection) removeFDLocked(id FDID) genericFD { + fd := c.fds[id] + if fd == nil { + log.Warningf("removeFDLocked called on non-existent FDID %d", id) + return nil + } + delete(c.fds, id) + return fd +} diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go new file mode 100644 index 000000000..28ba47112 --- /dev/null +++ b/pkg/lisafs/connection_test.go @@ -0,0 +1,194 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connection_test + +import ( + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/lisafs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +const ( + dynamicMsgID = lisafs.Channel + 1 + versionMsgID = dynamicMsgID + 1 +) + +var handlers = [...]lisafs.RPCHandler{ + lisafs.Error: lisafs.ErrorHandler, + lisafs.Mount: lisafs.MountHandler, + lisafs.Channel: lisafs.ChannelHandler, + dynamicMsgID: dynamicMsgHandler, + versionMsgID: versionHandler, +} + +// testServer implements lisafs.ServerImpl. +type testServer struct { + lisafs.Server +} + +var _ lisafs.ServerImpl = (*testServer)(nil) + +type testControlFD struct { + lisafs.ControlFD + lisafs.ControlFDImpl +} + +func (fd *testControlFD) FD() *lisafs.ControlFD { + return &fd.ControlFD +} + +// Mount implements lisafs.Mount. +func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) { + return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil +} + +// MaxMessageSize implements lisafs.MaxMessageSize. +func (s *testServer) MaxMessageSize() uint32 { + return lisafs.MaxMessageSize() +} + +// SupportedMessages implements lisafs.ServerImpl.SupportedMessages. +func (s *testServer) SupportedMessages() []lisafs.MID { + return []lisafs.MID{ + lisafs.Mount, + lisafs.Channel, + dynamicMsgID, + versionMsgID, + } +} + +func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) { + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + + ts := &testServer{} + ts.Server.InitTestOnly(ts, handlers[:]) + conn, err := ts.CreateConnection(serverSocket, false /* readonly */) + if err != nil { + t.Fatalf("starting connection failed: %v", err) + return + } + ts.StartConnection(conn) + + c, _, err := lisafs.NewClient(clientSocket, "/") + if err != nil { + t.Fatalf("client creation failed: %v", err) + } + + clientFn(c) + + c.Close() // This should trigger client and server shutdown. + ts.Wait() +} + +// TestStartUp tests that the server and client can be started up correctly. +func TestStartUp(t *testing.T) { + runServerClient(t, func(c *lisafs.Client) { + if c.IsSupported(lisafs.Error) { + t.Errorf("sending error messages should not be supported") + } + }) +} + +func TestUnsupportedMessage(t *testing.T) { + unsupportedM := lisafs.MID(len(handlers)) + runServerClient(t, func(c *lisafs.Client) { + if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP { + t.Errorf("expected EOPNOTSUPP but got err: %v", err) + } + }) +} + +func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { + var req lisafs.MsgDynamic + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Just echo back the message. + respPayloadLen := uint32(req.SizeBytes()) + req.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// TestStress stress tests sending many messages from various goroutines. +func TestStress(t *testing.T) { + runServerClient(t, func(c *lisafs.Client) { + concurrency := 8 + numMsgPerGoroutine := 5000 + var clientWg sync.WaitGroup + for i := 0; i < concurrency; i++ { + clientWg.Add(1) + go func() { + defer clientWg.Done() + + for j := 0; j < numMsgPerGoroutine; j++ { + // Create a massive random message. + var req lisafs.MsgDynamic + req.Randomize(100) + + var resp lisafs.MsgDynamic + if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil { + t.Errorf("SndRcvMessage: received unexpected error %v", err) + return + } + if !reflect.DeepEqual(&req, &resp) { + t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp) + } + } + }() + } + + clientWg.Wait() + }) +} + +func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { + // To be fair, usually handlers will create their own objects and return a + // pointer to those. Might be tempting to reuse above variables, but don't. + var rv lisafs.P9Version + rv.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Create a new response. + sv := lisafs.P9Version{ + MSize: rv.MSize, + Version: "9P2000.L.Google.11", + } + respPayloadLen := uint32(sv.SizeBytes()) + sv.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel. +func BenchmarkSendRecv(b *testing.B) { + b.ReportAllocs() + sendV := lisafs.P9Version{ + MSize: 1 << 20, + Version: "9P2000.L.Google.12", + } + + var recvV lisafs.P9Version + runServerClient(b, func(c *lisafs.Client) { + for i := 0; i < b.N; i++ { + if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil { + b.Fatalf("unexpected error occurred: %v", err) + } + } + }) +} diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go new file mode 100644 index 000000000..cc6919a1b --- /dev/null +++ b/pkg/lisafs/fd.go @@ -0,0 +1,374 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/sync" +) + +// FDID (file descriptor identifier) is used to identify FDs on a connection. +// Each connection has its own FDID namespace. +// +// +marshal slice:FDIDSlice +type FDID uint32 + +// InvalidFDID represents an invalid FDID. +const InvalidFDID FDID = 0 + +// Ok returns true if f is a valid FDID. +func (f FDID) Ok() bool { + return f != InvalidFDID +} + +// genericFD can represent a ControlFD or OpenFD. +type genericFD interface { + refsvfs2.RefCounter +} + +// A ControlFD is the gateway to the backing filesystem tree node. It is an +// unusual concept. This exists to provide a safe way to do path-based +// operations on the file. It performs operations that can modify the +// filesystem tree and synchronizes these operations. See ControlFDImpl for +// supported operations. +// +// It is not an inode, because multiple control FDs are allowed to exist on the +// same file. It is not a file descriptor because it is not tied to any access +// mode, i.e. a control FD can change its access mode based on the operation +// being performed. +// +// Reference Model: +// * When a control FD is created, the connection takes a ref on it which +// represents the client's ref on the FD. +// * The client can drop its ref via the Close RPC which will in turn make the +// connection drop its ref. +// * Each control FD holds a ref on its parent for its entire life time. +type ControlFD struct { + controlFDRefs + controlFDEntry + + // parent is the parent directory FD containing the file this FD represents. + // A ControlFD holds a ref on parent for its entire lifetime. If this FD + // represents the root, then parent is nil. parent may be a control FD from + // another connection (another mount point). parent is protected by the + // backing server's rename mutex. + parent *ControlFD + + // name is the file path's last component name. If this FD represents the + // root directory, then name is the mount path. name is protected by the + // backing server's rename mutex. + name string + + // children is a linked list of all children control FDs. As per reference + // model, all children hold a ref on this FD. + // children is protected by childrenMu and server's rename mutex. To have + // mutual exclusion, it is sufficient to: + // * Hold rename mutex for reading and lock childrenMu. OR + // * Or hold rename mutex for writing. + childrenMu sync.Mutex + children controlFDList + + // openFDs is a linked list of all FDs opened on this FD. As per reference + // model, all open FDs hold a ref on this FD. + openFDsMu sync.RWMutex + openFDs openFDList + + // All the following fields are immutable. + + // id is the unique FD identifier which identifies this FD on its connection. + id FDID + + // conn is the backing connection owning this FD. + conn *Connection + + // ftype is the file type of the backing inode. ftype.FileType() == ftype. + ftype linux.FileMode + + // impl is the control FD implementation which embeds this struct. It + // contains all the implementation specific details. + impl ControlFDImpl +} + +var _ genericFD = (*ControlFD)(nil) + +// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context +// parameter should never be used. It exists solely to comply with the +// refsvfs2.RefCounter interface. +func (fd *ControlFD) DecRef(context.Context) { + fd.controlFDRefs.DecRef(func() { + if fd.parent != nil { + fd.conn.server.RenameMu.RLock() + fd.parent.childrenMu.Lock() + fd.parent.children.Remove(fd) + fd.parent.childrenMu.Unlock() + fd.conn.server.RenameMu.RUnlock() + fd.parent.DecRef(nil) // Drop the ref on the parent. + } + fd.impl.Close(fd.conn) + }) +} + +// DecRefLocked is the same as DecRef except the added precondition. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) DecRefLocked() { + fd.controlFDRefs.DecRef(func() { + fd.clearParentLocked() + fd.impl.Close(fd.conn) + }) +} + +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) clearParentLocked() { + if fd.parent == nil { + return + } + fd.parent.childrenMu.Lock() + fd.parent.children.Remove(fd) + fd.parent.childrenMu.Unlock() + fd.parent.DecRefLocked() // Drop the ref on the parent. +} + +// Init must be called before first use of fd. It inserts fd into the +// filesystem tree. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) Init(c *Connection, parent *ControlFD, name string, mode linux.FileMode, impl ControlFDImpl) { + // Initialize fd with 1 ref which is transferred to c via c.insertFD(). + fd.controlFDRefs.InitRefs() + fd.conn = c + fd.id = c.insertFD(fd) + fd.name = name + fd.ftype = mode.FileType() + fd.impl = impl + fd.setParentLocked(parent) +} + +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) setParentLocked(parent *ControlFD) { + fd.parent = parent + if parent != nil { + parent.IncRef() // Hold a ref on parent. + parent.childrenMu.Lock() + parent.children.PushBack(fd) + parent.childrenMu.Unlock() + } +} + +// FileType returns the file mode only containing the file type bits. +func (fd *ControlFD) FileType() linux.FileMode { + return fd.ftype +} + +// IsDir indicates whether fd represents a directory. +func (fd *ControlFD) IsDir() bool { + return fd.ftype == unix.S_IFDIR +} + +// IsRegular indicates whether fd represents a regular file. +func (fd *ControlFD) IsRegular() bool { + return fd.ftype == unix.S_IFREG +} + +// IsSymlink indicates whether fd represents a symbolic link. +func (fd *ControlFD) IsSymlink() bool { + return fd.ftype == unix.S_IFLNK +} + +// IsSocket indicates whether fd represents a socket. +func (fd *ControlFD) IsSocket() bool { + return fd.ftype == unix.S_IFSOCK +} + +// NameLocked returns the backing file's last component name. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) NameLocked() string { + return fd.name +} + +// ParentLocked returns the parent control FD. Note that parent might be a +// control FD from another connection on this server. So its ID must not +// returned on this connection because FDIDs are local to their connection. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) ParentLocked() ControlFDImpl { + if fd.parent == nil { + return nil + } + return fd.parent.impl +} + +// ID returns fd's ID. +func (fd *ControlFD) ID() FDID { + return fd.id +} + +// FilePath returns the absolute path of the file fd was opened on. This is +// expensive and must not be called on hot paths. FilePath acquires the rename +// mutex for reading so callers should not be holding it. +func (fd *ControlFD) FilePath() string { + // Lock the rename mutex for reading to ensure that the filesystem tree is not + // changed while we traverse it upwards. + fd.conn.server.RenameMu.RLock() + defer fd.conn.server.RenameMu.RUnlock() + return fd.FilePathLocked() +} + +// FilePathLocked is the same as FilePath with the additional precondition. +// +// Precondition: server's rename mutex must be at least read locked. +func (fd *ControlFD) FilePathLocked() string { + // Walk upwards and prepend name to res. + var res fspath.Builder + for fd != nil { + res.PrependComponent(fd.name) + fd = fd.parent + } + return res.String() +} + +// ForEachOpenFD executes fn on each FD opened on fd. +func (fd *ControlFD) ForEachOpenFD(fn func(ofd OpenFDImpl)) { + fd.openFDsMu.RLock() + defer fd.openFDsMu.RUnlock() + for ofd := fd.openFDs.Front(); ofd != nil; ofd = ofd.Next() { + fn(ofd.impl) + } +} + +// OpenFD represents an open file descriptor on the protocol. It resonates +// closely with a Linux file descriptor. Its operations are limited to the +// file. Its operations are not allowed to modify or traverse the filesystem +// tree. See OpenFDImpl for the supported operations. +// +// Reference Model: +// * An OpenFD takes a reference on the control FD it was opened on. +type OpenFD struct { + openFDRefs + openFDEntry + + // All the following fields are immutable. + + // controlFD is the ControlFD on which this FD was opened. OpenFD holds a ref + // on controlFD for its entire lifetime. + controlFD *ControlFD + + // id is the unique FD identifier which identifies this FD on its connection. + id FDID + + // Access mode for this FD. + readable bool + writable bool + + // impl is the open FD implementation which embeds this struct. It + // contains all the implementation specific details. + impl OpenFDImpl +} + +var _ genericFD = (*OpenFD)(nil) + +// ID returns fd's ID. +func (fd *OpenFD) ID() FDID { + return fd.id +} + +// ControlFD returns the control FD on which this FD was opened. +func (fd *OpenFD) ControlFD() ControlFDImpl { + return fd.controlFD.impl +} + +// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context +// parameter should never be used. It exists solely to comply with the +// refsvfs2.RefCounter interface. +func (fd *OpenFD) DecRef(context.Context) { + fd.openFDRefs.DecRef(func() { + fd.controlFD.openFDsMu.Lock() + fd.controlFD.openFDs.Remove(fd) + fd.controlFD.openFDsMu.Unlock() + fd.controlFD.DecRef(nil) // Drop the ref on the control FD. + fd.impl.Close(fd.controlFD.conn) + }) +} + +// Init must be called before first use of fd. +func (fd *OpenFD) Init(cfd *ControlFD, flags uint32, impl OpenFDImpl) { + // Initialize fd with 1 ref which is transferred to c via c.insertFD(). + fd.openFDRefs.InitRefs() + fd.controlFD = cfd + fd.id = cfd.conn.insertFD(fd) + accessMode := flags & unix.O_ACCMODE + fd.readable = accessMode == unix.O_RDONLY || accessMode == unix.O_RDWR + fd.writable = accessMode == unix.O_WRONLY || accessMode == unix.O_RDWR + fd.impl = impl + cfd.IncRef() // Holds a ref on cfd for its lifetime. + cfd.openFDsMu.Lock() + cfd.openFDs.PushBack(fd) + cfd.openFDsMu.Unlock() +} + +// ControlFDImpl contains implementation details for a ControlFD. +// Implementations of ControlFDImpl should contain their associated ControlFD +// by value as their first field. +// +// The operations that perform path traversal or any modification to the +// filesystem tree must synchronize those modifications with the server's +// rename mutex. +type ControlFDImpl interface { + FD() *ControlFD + Close(c *Connection) + Stat(c *Connection, comm Communicator) (uint32, error) + SetStat(c *Connection, comm Communicator, stat SetStatReq) (uint32, error) + Walk(c *Connection, comm Communicator, path StringArray) (uint32, error) + WalkStat(c *Connection, comm Communicator, path StringArray) (uint32, error) + Open(c *Connection, comm Communicator, flags uint32) (uint32, error) + OpenCreate(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, flags uint32) (uint32, error) + Mkdir(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string) (uint32, error) + Mknod(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, minor uint32, major uint32) (uint32, error) + Symlink(c *Connection, comm Communicator, name string, target string, uid UID, gid GID) (uint32, error) + Link(c *Connection, comm Communicator, dir ControlFDImpl, name string) (uint32, error) + StatFS(c *Connection, comm Communicator) (uint32, error) + Readlink(c *Connection, comm Communicator) (uint32, error) + Connect(c *Connection, comm Communicator, sockType uint32) error + Unlink(c *Connection, name string, flags uint32) error + RenameLocked(c *Connection, newDir ControlFDImpl, newName string) (func(ControlFDImpl), func(), error) + GetXattr(c *Connection, comm Communicator, name string, size uint32) (uint32, error) + SetXattr(c *Connection, name string, value string, flags uint32) error + ListXattr(c *Connection, comm Communicator, size uint64) (uint32, error) + RemoveXattr(c *Connection, comm Communicator, name string) error +} + +// OpenFDImpl contains implementation details for a OpenFD. Implementations of +// OpenFDImpl should contain their associated OpenFD by value as their first +// field. +// +// Since these operations do not perform any path traversal or any modification +// to the filesystem tree, there is no need to synchronize with rename +// operations. +type OpenFDImpl interface { + FD() *OpenFD + Close(c *Connection) + Stat(c *Connection, comm Communicator) (uint32, error) + Sync(c *Connection) error + Write(c *Connection, comm Communicator, buf []byte, off uint64) (uint32, error) + Read(c *Connection, comm Communicator, off uint64, count uint32) (uint32, error) + Allocate(c *Connection, mode, off, length uint64) error + Flush(c *Connection) error + Getdent64(c *Connection, comm Communicator, count uint32, seek0 bool) (uint32, error) +} diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go new file mode 100644 index 000000000..82807734d --- /dev/null +++ b/pkg/lisafs/handlers.go @@ -0,0 +1,768 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "fmt" + "path" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +const ( + allowedOpenFlags = unix.O_ACCMODE | unix.O_TRUNC + setStatSupportedMask = unix.STATX_MODE | unix.STATX_UID | unix.STATX_GID | unix.STATX_SIZE | unix.STATX_ATIME | unix.STATX_MTIME +) + +// RPCHandler defines a handler that is invoked when the associated message is +// received. The handler is responsible for: +// +// * Unmarshalling the request from the passed payload and interpreting it. +// * Marshalling the response into the communicator's payload buffer. +// * Return the number of payload bytes written. +// * Donate any FDs (if needed) to comm which will in turn donate it to client. +type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) + +var handlers = [...]RPCHandler{ + Error: ErrorHandler, + Mount: MountHandler, + Channel: ChannelHandler, + FStat: FStatHandler, + SetStat: SetStatHandler, + Walk: WalkHandler, + WalkStat: WalkStatHandler, + OpenAt: OpenAtHandler, + OpenCreateAt: OpenCreateAtHandler, + Close: CloseHandler, + FSync: FSyncHandler, + PWrite: PWriteHandler, + PRead: PReadHandler, + MkdirAt: MkdirAtHandler, + MknodAt: MknodAtHandler, + SymlinkAt: SymlinkAtHandler, + LinkAt: LinkAtHandler, + FStatFS: FStatFSHandler, + FAllocate: FAllocateHandler, + ReadLinkAt: ReadLinkAtHandler, + Flush: FlushHandler, + Connect: ConnectHandler, + UnlinkAt: UnlinkAtHandler, + RenameAt: RenameAtHandler, + Getdents64: Getdents64Handler, + FGetXattr: FGetXattrHandler, + FSetXattr: FSetXattrHandler, + FListXattr: FListXattrHandler, + FRemoveXattr: FRemoveXattrHandler, +} + +// ErrorHandler handles Error message. +func ErrorHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + // Client should never send Error. + return 0, unix.EINVAL +} + +// MountHandler handles the Mount RPC. Note that there can not be concurrent +// executions of MountHandler on a connection because the connection enforces +// that Mount is the first message on the connection. Only after the connection +// has been successfully mounted can other channels be created. +func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req MountReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + mountPath := path.Clean(string(req.MountPath)) + if !filepath.IsAbs(mountPath) { + log.Warningf("mountPath %q is not absolute", mountPath) + return 0, unix.EINVAL + } + + if c.mounted { + log.Warningf("connection has already been mounted at %q", mountPath) + return 0, unix.EBUSY + } + + rootFD, rootIno, err := c.ServerImpl().Mount(c, mountPath) + if err != nil { + return 0, err + } + + c.server.addMountPoint(rootFD.FD()) + c.mounted = true + resp := MountResp{ + Root: rootIno, + SupportedMs: c.ServerImpl().SupportedMessages(), + MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()), + } + respPayloadLen := uint32(resp.SizeBytes()) + resp.MarshalBytes(comm.PayloadBuf(respPayloadLen)) + return respPayloadLen, nil +} + +// ChannelHandler handles the Channel RPC. +func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize()) + if err != nil { + return 0, err + } + + // Start servicing the channel in a separate goroutine. + c.activeWg.Add(1) + go func() { + if err := c.service(ch); err != nil { + // Don't log shutdown error which is expected during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("lisafs.Connection.service(channel = @%p): %v", ch, err) + } + } + c.activeWg.Done() + }() + + clientDataFD, err := unix.Dup(desc.FD) + if err != nil { + unix.Close(fdSock) + ch.shutdown() + return 0, err + } + + // Respond to client with successful channel creation message. + if err := comm.DonateFD(clientDataFD); err != nil { + return 0, err + } + if err := comm.DonateFD(fdSock); err != nil { + return 0, err + } + resp := ChannelResp{ + dataOffset: desc.Offset, + dataLength: uint64(desc.Length), + } + respLen := uint32(resp.SizeBytes()) + resp.MarshalUnsafe(comm.PayloadBuf(respLen)) + return respLen, nil +} + +// FStatHandler handles the FStat RPC. +func FStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req StatReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.lookupFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + switch t := fd.(type) { + case *ControlFD: + return t.impl.Stat(c, comm) + case *OpenFD: + return t.impl.Stat(c, comm) + default: + panic(fmt.Sprintf("unknown fd type %T", t)) + } +} + +// SetStatHandler handles the SetStat RPC. +func SetStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + + var req SetStatReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + if req.Mask&^setStatSupportedMask != 0 { + return 0, unix.EPERM + } + + return fd.impl.SetStat(c, comm, req) +} + +// WalkHandler handles the Walk RPC. +func WalkHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req WalkReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + for _, name := range req.Path { + if err := checkSafeName(name); err != nil { + return 0, err + } + } + + return fd.impl.Walk(c, comm, req.Path) +} + +// WalkStatHandler handles the WalkStat RPC. +func WalkStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req WalkReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + // Note that this fd is allowed to not actually be a directory when the + // only path component to walk is "" (self). + if !fd.IsDir() { + if len(req.Path) > 1 || (len(req.Path) == 1 && len(req.Path[0]) > 0) { + return 0, unix.ENOTDIR + } + } + for i, name := range req.Path { + // First component is allowed to be "". + if i == 0 && len(name) == 0 { + continue + } + if err := checkSafeName(name); err != nil { + return 0, err + } + } + + return fd.impl.WalkStat(c, comm, req.Path) +} + +// OpenAtHandler handles the OpenAt RPC. +func OpenAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req OpenAtReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + // Only keep allowed open flags. + if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags { + log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags) + req.Flags = allowedFlags + } + + accessMode := req.Flags & unix.O_ACCMODE + trunc := req.Flags&unix.O_TRUNC != 0 + if c.readonly && (accessMode != unix.O_RDONLY || trunc) { + return 0, unix.EROFS + } + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if fd.IsDir() { + // Directory is not truncatable and must be opened with O_RDONLY. + if accessMode != unix.O_RDONLY || trunc { + return 0, unix.EISDIR + } + } + + return fd.impl.Open(c, comm, req.Flags) +} + +// OpenCreateAtHandler handles the OpenCreateAt RPC. +func OpenCreateAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req OpenCreateAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Only keep allowed open flags. + if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags { + log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags) + req.Flags = allowedFlags + } + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + + return fd.impl.OpenCreate(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Flags)) +} + +// CloseHandler handles the Close RPC. +func CloseHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req CloseReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + for _, fd := range req.FDs { + c.RemoveFD(fd) + } + + // There is no response message for this. + return 0, nil +} + +// FSyncHandler handles the FSync RPC. +func FSyncHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FsyncReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + // Return the first error we encounter, but sync everything we can + // regardless. + var retErr error + for _, fdid := range req.FDs { + if err := c.fsyncFD(fdid); err != nil && retErr == nil { + retErr = err + } + } + + // There is no response message for this. + return 0, retErr +} + +func (c *Connection) fsyncFD(id FDID) error { + fd, err := c.LookupOpenFD(id) + if err != nil { + return err + } + return fd.impl.Sync(c) +} + +// PWriteHandler handles the PWrite RPC. +func PWriteHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req PWriteReq + // Note that it is an optimized Unmarshal operation which avoids any buffer + // allocation and copying. req.Buf just points to payload. This is safe to do + // as the handler owns payload and req's lifetime is limited to the handler. + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + if !fd.writable { + return 0, unix.EBADF + } + return fd.impl.Write(c, comm, req.Buf, uint64(req.Offset)) +} + +// PReadHandler handles the PRead RPC. +func PReadHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req PReadReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.readable { + return 0, unix.EBADF + } + return fd.impl.Read(c, comm, req.Offset, req.Count) +} + +// MkdirAtHandler handles the MkdirAt RPC. +func MkdirAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req MkdirAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Mkdir(c, comm, req.Mode, req.UID, req.GID, name) +} + +// MknodAtHandler handles the MknodAt RPC. +func MknodAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req MknodAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Mknod(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Minor), uint32(req.Major)) +} + +// SymlinkAtHandler handles the SymlinkAt RPC. +func SymlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req SymlinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return fd.impl.Symlink(c, comm, name, string(req.Target), req.UID, req.GID) +} + +// LinkAtHandler handles the LinkAt RPC. +func LinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req LinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + + targetFD, err := c.LookupControlFD(req.Target) + if err != nil { + return 0, err + } + return targetFD.impl.Link(c, comm, fd.impl, name) +} + +// FStatFSHandler handles the FStatFS RPC. +func FStatFSHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FStatFSReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.StatFS(c, comm) +} + +// FAllocateHandler handles the FAllocate RPC. +func FAllocateHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FAllocateReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.writable { + return 0, unix.EBADF + } + return 0, fd.impl.Allocate(c, req.Mode, req.Offset, req.Length) +} + +// ReadLinkAtHandler handles the ReadLinkAt RPC. +func ReadLinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req ReadLinkAtReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsSymlink() { + return 0, unix.EINVAL + } + return fd.impl.Readlink(c, comm) +} + +// FlushHandler handles the Flush RPC. +func FlushHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FlushReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + + return 0, fd.impl.Flush(c) +} + +// ConnectHandler handles the Connect RPC. +func ConnectHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req ConnectReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsSocket() { + return 0, unix.ENOTSOCK + } + return 0, fd.impl.Connect(c, comm, req.SockType) +} + +// UnlinkAtHandler handles the UnlinkAt RPC. +func UnlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req UnlinkAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + name := string(req.Name) + if err := checkSafeName(name); err != nil { + return 0, err + } + + fd, err := c.LookupControlFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsDir() { + return 0, unix.ENOTDIR + } + return 0, fd.impl.Unlink(c, name, uint32(req.Flags)) +} + +// RenameAtHandler handles the RenameAt RPC. +func RenameAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req RenameAtReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + newName := string(req.NewName) + if err := checkSafeName(newName); err != nil { + return 0, err + } + + renamed, err := c.LookupControlFD(req.Renamed) + if err != nil { + return 0, err + } + defer renamed.DecRef(nil) + + newDir, err := c.LookupControlFD(req.NewDir) + if err != nil { + return 0, err + } + defer newDir.DecRef(nil) + if !newDir.IsDir() { + return 0, unix.ENOTDIR + } + + // Hold RenameMu for writing during rename, this is important. + c.server.RenameMu.Lock() + defer c.server.RenameMu.Unlock() + + if renamed.parent == nil { + // renamed is root. + return 0, unix.EBUSY + } + + oldParentPath := renamed.parent.FilePathLocked() + oldPath := oldParentPath + "/" + renamed.name + if newName == renamed.name && oldParentPath == newDir.FilePathLocked() { + // Nothing to do. + return 0, nil + } + + updateControlFD, cleanUp, err := renamed.impl.RenameLocked(c, newDir.impl, newName) + if err != nil { + return 0, err + } + + c.server.forEachMountPoint(func(root *ControlFD) { + if !strings.HasPrefix(oldPath, root.name) { + return + } + pit := fspath.Parse(oldPath[len(root.name):]).Begin + root.renameRecursiveLocked(newDir, newName, pit, updateControlFD) + }) + + if cleanUp != nil { + cleanUp() + } + return 0, nil +} + +// Precondition: rename mutex must be locked for writing. +func (fd *ControlFD) renameRecursiveLocked(newDir *ControlFD, newName string, pit fspath.Iterator, updateControlFD func(ControlFDImpl)) { + if !pit.Ok() { + // fd should be renamed. + fd.clearParentLocked() + fd.setParentLocked(newDir) + fd.name = newName + if updateControlFD != nil { + updateControlFD(fd.impl) + } + return + } + + cur := pit.String() + next := pit.Next() + // No need to hold fd.childrenMu because RenameMu is locked for writing. + for child := fd.children.Front(); child != nil; child = child.Next() { + if child.name == cur { + child.renameRecursiveLocked(newDir, newName, next, updateControlFD) + } + } +} + +// Getdents64Handler handles the Getdents64 RPC. +func Getdents64Handler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req Getdents64Req + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupOpenFD(req.DirFD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.controlFD.IsDir() { + return 0, unix.ENOTDIR + } + + seek0 := false + if req.Count < 0 { + seek0 = true + req.Count = -req.Count + } + return fd.impl.Getdent64(c, comm, uint32(req.Count), seek0) +} + +// FGetXattrHandler handles the FGetXattr RPC. +func FGetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FGetXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.GetXattr(c, comm, string(req.Name), uint32(req.BufSize)) +} + +// FSetXattrHandler handles the FSetXattr RPC. +func FSetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FSetXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return 0, fd.impl.SetXattr(c, string(req.Name), string(req.Value), uint32(req.Flags)) +} + +// FListXattrHandler handles the FListXattr RPC. +func FListXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req FListXattrReq + req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return fd.impl.ListXattr(c, comm, req.Size) +} + +// FRemoveXattrHandler handles the FRemoveXattr RPC. +func FRemoveXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + if c.readonly { + return 0, unix.EROFS + } + var req FRemoveXattrReq + req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) + + fd, err := c.LookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + return 0, fd.impl.RemoveXattr(c, comm, string(req.Name)) +} + +// checkSafeName validates the name and returns nil or returns an error. +func checkSafeName(name string) error { + if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." { + return nil + } + return unix.EINVAL +} diff --git a/pkg/lisafs/lisafs.go b/pkg/lisafs/lisafs.go new file mode 100644 index 000000000..4d8e956ab --- /dev/null +++ b/pkg/lisafs/lisafs.go @@ -0,0 +1,18 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package lisafs (LInux SAndbox FileSystem) defines the protocol for +// filesystem RPCs between an untrusted Sandbox (client) and a trusted +// filesystem server. +package lisafs diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go new file mode 100644 index 000000000..722afd0be --- /dev/null +++ b/pkg/lisafs/message.go @@ -0,0 +1,1251 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math" + "os" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// Messages have two parts: +// * A transport header used to decipher received messages. +// * A byte array referred to as "payload" which contains the actual message. +// +// "dataLen" refers to the size of both combined. + +// MID (message ID) is used to identify messages to parse from payload. +// +// +marshal slice:MIDSlice +type MID uint16 + +// These constants are used to identify their corresponding message types. +const ( + // Error is only used in responses to pass errors to client. + Error MID = 0 + + // Mount is used to establish connection between the client and server mount + // point. lisafs requires that the client makes a successful Mount RPC before + // making other RPCs. + Mount MID = 1 + + // Channel requests to start a new communicational channel. + Channel MID = 2 + + // FStat requests the stat(2) results for a specified file. + FStat MID = 3 + + // SetStat requests to change file attributes. Note that there is no one + // corresponding Linux syscall. This is a conglomeration of fchmod(2), + // fchown(2), ftruncate(2) and futimesat(2). + SetStat MID = 4 + + // Walk requests to walk the specified path starting from the specified + // directory. Server-side path traversal is terminated preemptively on + // symlinks entries because they can cause non-linear traversal. + Walk MID = 5 + + // WalkStat is the same as Walk, except the following differences: + // * If the first path component is "", then it also returns stat results + // for the directory where the walk starts. + // * Does not return Inode, just the Stat results for each path component. + WalkStat MID = 6 + + // OpenAt is analogous to openat(2). It does not perform any walk. It merely + // duplicates the control FD with the open flags passed. + OpenAt MID = 7 + + // OpenCreateAt is analogous to openat(2) with O_CREAT|O_EXCL added to flags. + // It also returns the newly created file inode. + OpenCreateAt MID = 8 + + // Close is analogous to close(2) but can work on multiple FDs. + Close MID = 9 + + // FSync is analogous to fsync(2) but can work on multiple FDs. + FSync MID = 10 + + // PWrite is analogous to pwrite(2). + PWrite MID = 11 + + // PRead is analogous to pread(2). + PRead MID = 12 + + // MkdirAt is analogous to mkdirat(2). + MkdirAt MID = 13 + + // MknodAt is analogous to mknodat(2). + MknodAt MID = 14 + + // SymlinkAt is analogous to symlinkat(2). + SymlinkAt MID = 15 + + // LinkAt is analogous to linkat(2). + LinkAt MID = 16 + + // FStatFS is analogous to fstatfs(2). + FStatFS MID = 17 + + // FAllocate is analogous to fallocate(2). + FAllocate MID = 18 + + // ReadLinkAt is analogous to readlinkat(2). + ReadLinkAt MID = 19 + + // Flush cleans up the file state. Its behavior is implementation + // dependent and might not even be supported in server implementations. + Flush MID = 20 + + // Connect is loosely analogous to connect(2). + Connect MID = 21 + + // UnlinkAt is analogous to unlinkat(2). + UnlinkAt MID = 22 + + // RenameAt is loosely analogous to renameat(2). + RenameAt MID = 23 + + // Getdents64 is analogous to getdents64(2). + Getdents64 MID = 24 + + // FGetXattr is analogous to fgetxattr(2). + FGetXattr MID = 25 + + // FSetXattr is analogous to fsetxattr(2). + FSetXattr MID = 26 + + // FListXattr is analogous to flistxattr(2). + FListXattr MID = 27 + + // FRemoveXattr is analogous to fremovexattr(2). + FRemoveXattr MID = 28 +) + +const ( + // NoUID is a sentinel used to indicate no valid UID. + NoUID UID = math.MaxUint32 + + // NoGID is a sentinel used to indicate no valid GID. + NoGID GID = math.MaxUint32 +) + +// MaxMessageSize is the recommended max message size that can be used by +// connections. Server implementations may choose to use other values. +func MaxMessageSize() uint32 { + // Return HugePageSize - PageSize so that when flipcall packet window is + // created with MaxMessageSize() + flipcall header size + channel header + // size, HugePageSize is allocated and can be backed by a single huge page + // if supported by the underlying memfd. + return uint32(hostarch.HugePageSize - os.Getpagesize()) +} + +// TODO(gvisor.dev/issue/6450): Once this is resolved: +// * Update manual implementations and function signatures. +// * Update RPC handlers and appropriate callers to handle errors correctly. +// * Update manual implementations to get rid of buffer shifting. + +// UID represents a user ID. +// +// +marshal +type UID uint32 + +// Ok returns true if uid is not NoUID. +func (uid UID) Ok() bool { + return uid != NoUID +} + +// GID represents a group ID. +// +// +marshal +type GID uint32 + +// Ok returns true if gid is not NoGID. +func (gid GID) Ok() bool { + return gid != NoGID +} + +// NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes. +func NoopMarshal([]byte) {} + +// NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes. +func NoopUnmarshal([]byte) {} + +// SizedString represents a string in memory. The marshalled string bytes are +// preceded by a uint32 signifying the string length. +type SizedString string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SizedString) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + len(*s) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SizedString) MarshalBytes(dst []byte) { + strLen := primitive.Uint32(len(*s)) + strLen.MarshalUnsafe(dst) + dst = dst[strLen.SizeBytes():] + // Copy without any allocation. + copy(dst[:strLen], *s) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SizedString) UnmarshalBytes(src []byte) { + var strLen primitive.Uint32 + strLen.UnmarshalUnsafe(src) + src = src[strLen.SizeBytes():] + // Take the hit, this leads to an allocation + memcpy. No way around it. + *s = SizedString(src[:strLen]) +} + +// StringArray represents an array of SizedStrings in memory. The marshalled +// array data is preceded by a uint32 signifying the array length. +type StringArray []string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *StringArray) SizeBytes() int { + size := (*primitive.Uint32)(nil).SizeBytes() + for _, str := range *s { + sstr := SizedString(str) + size += sstr.SizeBytes() + } + return size +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *StringArray) MarshalBytes(dst []byte) { + arrLen := primitive.Uint32(len(*s)) + arrLen.MarshalUnsafe(dst) + dst = dst[arrLen.SizeBytes():] + for _, str := range *s { + sstr := SizedString(str) + sstr.MarshalBytes(dst) + dst = dst[sstr.SizeBytes():] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *StringArray) UnmarshalBytes(src []byte) { + var arrLen primitive.Uint32 + arrLen.UnmarshalUnsafe(src) + src = src[arrLen.SizeBytes():] + + if cap(*s) < int(arrLen) { + *s = make([]string, arrLen) + } else { + *s = (*s)[:arrLen] + } + + for i := primitive.Uint32(0); i < arrLen; i++ { + var sstr SizedString + sstr.UnmarshalBytes(src) + src = src[sstr.SizeBytes():] + (*s)[i] = string(sstr) + } +} + +// Inode represents an inode on the remote filesystem. +// +// +marshal slice:InodeSlice +type Inode struct { + ControlFD FDID + _ uint32 // Need to make struct packed. + Stat linux.Statx +} + +// MountReq represents a Mount request. +type MountReq struct { + MountPath SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MountReq) SizeBytes() int { + return m.MountPath.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MountReq) MarshalBytes(dst []byte) { + m.MountPath.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MountReq) UnmarshalBytes(src []byte) { + m.MountPath.UnmarshalBytes(src) +} + +// MountResp represents a Mount response. +type MountResp struct { + Root Inode + // MaxMessageSize is the maximum size of messages communicated between the + // client and server in bytes. This includes the communication header. + MaxMessageSize primitive.Uint32 + // SupportedMs holds all the supported messages. + SupportedMs []MID +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MountResp) SizeBytes() int { + return m.Root.SizeBytes() + + m.MaxMessageSize.SizeBytes() + + (*primitive.Uint16)(nil).SizeBytes() + + (len(m.SupportedMs) * (*MID)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MountResp) MarshalBytes(dst []byte) { + m.Root.MarshalUnsafe(dst) + dst = dst[m.Root.SizeBytes():] + m.MaxMessageSize.MarshalUnsafe(dst) + dst = dst[m.MaxMessageSize.SizeBytes():] + numSupported := primitive.Uint16(len(m.SupportedMs)) + numSupported.MarshalBytes(dst) + dst = dst[numSupported.SizeBytes():] + MarshalUnsafeMIDSlice(m.SupportedMs, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MountResp) UnmarshalBytes(src []byte) { + m.Root.UnmarshalUnsafe(src) + src = src[m.Root.SizeBytes():] + m.MaxMessageSize.UnmarshalUnsafe(src) + src = src[m.MaxMessageSize.SizeBytes():] + var numSupported primitive.Uint16 + numSupported.UnmarshalBytes(src) + src = src[numSupported.SizeBytes():] + m.SupportedMs = make([]MID, numSupported) + UnmarshalUnsafeMIDSlice(m.SupportedMs, src) +} + +// ChannelResp is the response to the create channel request. +// +// +marshal +type ChannelResp struct { + dataOffset int64 + dataLength uint64 +} + +// ErrorResp is returned to represent an error while handling a request. +// +// +marshal +type ErrorResp struct { + errno uint32 +} + +// StatReq requests the stat results for the specified FD. +// +// +marshal +type StatReq struct { + FD FDID +} + +// SetStatReq is used to set attributeds on FDs. +// +// +marshal +type SetStatReq struct { + FD FDID + _ uint32 + Mask uint32 + Mode uint32 // Only permissions part is settable. + UID UID + GID GID + Size uint64 + Atime linux.Timespec + Mtime linux.Timespec +} + +// SetStatResp is used to communicate SetStat results. It contains a mask +// representing the failed changes. It also contains the errno of the failed +// set attribute operation. If multiple operations failed then any of those +// errnos can be returned. +// +// +marshal +type SetStatResp struct { + FailureMask uint32 + FailureErrNo uint32 +} + +// WalkReq is used to request to walk multiple path components at once. This +// is used for both Walk and WalkStat. +type WalkReq struct { + DirFD FDID + Path StringArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkReq) SizeBytes() int { + return w.DirFD.SizeBytes() + w.Path.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkReq) MarshalBytes(dst []byte) { + w.DirFD.MarshalUnsafe(dst) + dst = dst[w.DirFD.SizeBytes():] + w.Path.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkReq) UnmarshalBytes(src []byte) { + w.DirFD.UnmarshalUnsafe(src) + src = src[w.DirFD.SizeBytes():] + w.Path.UnmarshalBytes(src) +} + +// WalkStatus is used to indicate the reason for partial/unsuccessful server +// side Walk operations. Please note that partial/unsuccessful walk operations +// do not necessarily fail the RPC. The RPC is successful with a failure hint +// which can be used by the client to infer server-side state. +type WalkStatus = primitive.Uint8 + +const ( + // WalkSuccess indicates that all path components were successfully walked. + WalkSuccess WalkStatus = iota + + // WalkComponentDoesNotExist indicates that the walk was prematurely + // terminated because an intermediate path component does not exist on + // server. The results of all previous existing path components is returned. + WalkComponentDoesNotExist + + // WalkComponentSymlink indicates that the walk was prematurely + // terminated because an intermediate path component was a symlink. It is not + // safe to resolve symlinks remotely (unaware of mount points). + WalkComponentSymlink +) + +// WalkResp is used to communicate the inodes walked by the server. +type WalkResp struct { + Status WalkStatus + Inodes []Inode +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkResp) SizeBytes() int { + return w.Status.SizeBytes() + + (*primitive.Uint32)(nil).SizeBytes() + (len(w.Inodes) * (*Inode)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkResp) MarshalBytes(dst []byte) { + w.Status.MarshalUnsafe(dst) + dst = dst[w.Status.SizeBytes():] + + numInodes := primitive.Uint32(len(w.Inodes)) + numInodes.MarshalUnsafe(dst) + dst = dst[numInodes.SizeBytes():] + + MarshalUnsafeInodeSlice(w.Inodes, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkResp) UnmarshalBytes(src []byte) { + w.Status.UnmarshalUnsafe(src) + src = src[w.Status.SizeBytes():] + + var numInodes primitive.Uint32 + numInodes.UnmarshalUnsafe(src) + src = src[numInodes.SizeBytes():] + + if cap(w.Inodes) < int(numInodes) { + w.Inodes = make([]Inode, numInodes) + } else { + w.Inodes = w.Inodes[:numInodes] + } + UnmarshalUnsafeInodeSlice(w.Inodes, src) +} + +// WalkStatResp is used to communicate stat results for WalkStat. +type WalkStatResp struct { + Stats []linux.Statx +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WalkStatResp) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (len(w.Stats) * linux.SizeOfStatx) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WalkStatResp) MarshalBytes(dst []byte) { + numStats := primitive.Uint32(len(w.Stats)) + numStats.MarshalUnsafe(dst) + dst = dst[numStats.SizeBytes():] + + linux.MarshalUnsafeStatxSlice(w.Stats, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WalkStatResp) UnmarshalBytes(src []byte) { + var numStats primitive.Uint32 + numStats.UnmarshalUnsafe(src) + src = src[numStats.SizeBytes():] + + if cap(w.Stats) < int(numStats) { + w.Stats = make([]linux.Statx, numStats) + } else { + w.Stats = w.Stats[:numStats] + } + linux.UnmarshalUnsafeStatxSlice(w.Stats, src) +} + +// OpenAtReq is used to open existing FDs with the specified flags. +// +// +marshal +type OpenAtReq struct { + FD FDID + Flags uint32 +} + +// OpenAtResp is used to communicate the newly created FD. +// +// +marshal +type OpenAtResp struct { + NewFD FDID +} + +// +marshal +type createCommon struct { + DirFD FDID + Mode linux.FileMode + _ uint16 // Need to make struct packed. + UID UID + GID GID +} + +// OpenCreateAtReq is used to make OpenCreateAt requests. +type OpenCreateAtReq struct { + createCommon + Name SizedString + Flags primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (o *OpenCreateAtReq) SizeBytes() int { + return o.createCommon.SizeBytes() + o.Name.SizeBytes() + o.Flags.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (o *OpenCreateAtReq) MarshalBytes(dst []byte) { + o.createCommon.MarshalUnsafe(dst) + dst = dst[o.createCommon.SizeBytes():] + o.Name.MarshalBytes(dst) + dst = dst[o.Name.SizeBytes():] + o.Flags.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) { + o.createCommon.UnmarshalUnsafe(src) + src = src[o.createCommon.SizeBytes():] + o.Name.UnmarshalBytes(src) + src = src[o.Name.SizeBytes():] + o.Flags.UnmarshalUnsafe(src) +} + +// OpenCreateAtResp is used to communicate successful OpenCreateAt results. +// +// +marshal +type OpenCreateAtResp struct { + Child Inode + NewFD FDID + _ uint32 // Need to make struct packed. +} + +// FdArray is a utility struct which implements a marshallable type for +// communicating an array of FDIDs. In memory, the array data is preceded by a +// uint32 denoting the array length. +type FdArray []FDID + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FdArray) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (len(*f) * (*FDID)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FdArray) MarshalBytes(dst []byte) { + arrLen := primitive.Uint32(len(*f)) + arrLen.MarshalUnsafe(dst) + dst = dst[arrLen.SizeBytes():] + MarshalUnsafeFDIDSlice(*f, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FdArray) UnmarshalBytes(src []byte) { + var arrLen primitive.Uint32 + arrLen.UnmarshalUnsafe(src) + src = src[arrLen.SizeBytes():] + if cap(*f) < int(arrLen) { + *f = make(FdArray, arrLen) + } else { + *f = (*f)[:arrLen] + } + UnmarshalUnsafeFDIDSlice(*f, src) +} + +// CloseReq is used to close(2) FDs. +type CloseReq struct { + FDs FdArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (c *CloseReq) SizeBytes() int { + return c.FDs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *CloseReq) MarshalBytes(dst []byte) { + c.FDs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *CloseReq) UnmarshalBytes(src []byte) { + c.FDs.UnmarshalBytes(src) +} + +// FsyncReq is used to fsync(2) FDs. +type FsyncReq struct { + FDs FdArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FsyncReq) SizeBytes() int { + return f.FDs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FsyncReq) MarshalBytes(dst []byte) { + f.FDs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FsyncReq) UnmarshalBytes(src []byte) { + f.FDs.UnmarshalBytes(src) +} + +// PReadReq is used to pread(2) on an FD. +// +// +marshal +type PReadReq struct { + Offset uint64 + FD FDID + Count uint32 +} + +// PReadResp is used to return the result of pread(2). +type PReadResp struct { + NumBytes primitive.Uint32 + Buf []byte +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *PReadResp) SizeBytes() int { + return r.NumBytes.SizeBytes() + int(r.NumBytes) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *PReadResp) MarshalBytes(dst []byte) { + r.NumBytes.MarshalUnsafe(dst) + dst = dst[r.NumBytes.SizeBytes():] + copy(dst[:r.NumBytes], r.Buf[:r.NumBytes]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *PReadResp) UnmarshalBytes(src []byte) { + r.NumBytes.UnmarshalUnsafe(src) + src = src[r.NumBytes.SizeBytes():] + + // We expect the client to have already allocated r.Buf. r.Buf probably + // (optimally) points to usermem. Directly copy into that. + copy(r.Buf[:r.NumBytes], src[:r.NumBytes]) +} + +// PWriteReq is used to pwrite(2) on an FD. +type PWriteReq struct { + Offset primitive.Uint64 + FD FDID + NumBytes primitive.Uint32 + Buf []byte +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *PWriteReq) SizeBytes() int { + return w.Offset.SizeBytes() + w.FD.SizeBytes() + w.NumBytes.SizeBytes() + int(w.NumBytes) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *PWriteReq) MarshalBytes(dst []byte) { + w.Offset.MarshalUnsafe(dst) + dst = dst[w.Offset.SizeBytes():] + w.FD.MarshalUnsafe(dst) + dst = dst[w.FD.SizeBytes():] + w.NumBytes.MarshalUnsafe(dst) + dst = dst[w.NumBytes.SizeBytes():] + copy(dst[:w.NumBytes], w.Buf[:w.NumBytes]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *PWriteReq) UnmarshalBytes(src []byte) { + w.Offset.UnmarshalUnsafe(src) + src = src[w.Offset.SizeBytes():] + w.FD.UnmarshalUnsafe(src) + src = src[w.FD.SizeBytes():] + w.NumBytes.UnmarshalUnsafe(src) + src = src[w.NumBytes.SizeBytes():] + + // This is an optimization. Assuming that the server is making this call, it + // is safe to just point to src rather than allocating and copying. + w.Buf = src[:w.NumBytes] +} + +// PWriteResp is used to return the result of pwrite(2). +// +// +marshal +type PWriteResp struct { + Count uint64 +} + +// MkdirAtReq is used to make MkdirAt requests. +type MkdirAtReq struct { + createCommon + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MkdirAtReq) SizeBytes() int { + return m.createCommon.SizeBytes() + m.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MkdirAtReq) MarshalBytes(dst []byte) { + m.createCommon.MarshalUnsafe(dst) + dst = dst[m.createCommon.SizeBytes():] + m.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MkdirAtReq) UnmarshalBytes(src []byte) { + m.createCommon.UnmarshalUnsafe(src) + src = src[m.createCommon.SizeBytes():] + m.Name.UnmarshalBytes(src) +} + +// MkdirAtResp is the response to a successful MkdirAt request. +// +// +marshal +type MkdirAtResp struct { + ChildDir Inode +} + +// MknodAtReq is used to make MknodAt requests. +type MknodAtReq struct { + createCommon + Name SizedString + Minor primitive.Uint32 + Major primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MknodAtReq) SizeBytes() int { + return m.createCommon.SizeBytes() + m.Name.SizeBytes() + m.Minor.SizeBytes() + m.Major.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MknodAtReq) MarshalBytes(dst []byte) { + m.createCommon.MarshalUnsafe(dst) + dst = dst[m.createCommon.SizeBytes():] + m.Name.MarshalBytes(dst) + dst = dst[m.Name.SizeBytes():] + m.Minor.MarshalUnsafe(dst) + dst = dst[m.Minor.SizeBytes():] + m.Major.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MknodAtReq) UnmarshalBytes(src []byte) { + m.createCommon.UnmarshalUnsafe(src) + src = src[m.createCommon.SizeBytes():] + m.Name.UnmarshalBytes(src) + src = src[m.Name.SizeBytes():] + m.Minor.UnmarshalUnsafe(src) + src = src[m.Minor.SizeBytes():] + m.Major.UnmarshalUnsafe(src) +} + +// MknodAtResp is the response to a successful MknodAt request. +// +// +marshal +type MknodAtResp struct { + Child Inode +} + +// SymlinkAtReq is used to make SymlinkAt request. +type SymlinkAtReq struct { + DirFD FDID + Name SizedString + Target SizedString + UID UID + GID GID +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SymlinkAtReq) SizeBytes() int { + return s.DirFD.SizeBytes() + s.Name.SizeBytes() + s.Target.SizeBytes() + s.UID.SizeBytes() + s.GID.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SymlinkAtReq) MarshalBytes(dst []byte) { + s.DirFD.MarshalUnsafe(dst) + dst = dst[s.DirFD.SizeBytes():] + s.Name.MarshalBytes(dst) + dst = dst[s.Name.SizeBytes():] + s.Target.MarshalBytes(dst) + dst = dst[s.Target.SizeBytes():] + s.UID.MarshalUnsafe(dst) + dst = dst[s.UID.SizeBytes():] + s.GID.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SymlinkAtReq) UnmarshalBytes(src []byte) { + s.DirFD.UnmarshalUnsafe(src) + src = src[s.DirFD.SizeBytes():] + s.Name.UnmarshalBytes(src) + src = src[s.Name.SizeBytes():] + s.Target.UnmarshalBytes(src) + src = src[s.Target.SizeBytes():] + s.UID.UnmarshalUnsafe(src) + src = src[s.UID.SizeBytes():] + s.GID.UnmarshalUnsafe(src) +} + +// SymlinkAtResp is the response to a successful SymlinkAt request. +// +// +marshal +type SymlinkAtResp struct { + Symlink Inode +} + +// LinkAtReq is used to make LinkAt requests. +type LinkAtReq struct { + DirFD FDID + Target FDID + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (l *LinkAtReq) SizeBytes() int { + return l.DirFD.SizeBytes() + l.Target.SizeBytes() + l.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (l *LinkAtReq) MarshalBytes(dst []byte) { + l.DirFD.MarshalUnsafe(dst) + dst = dst[l.DirFD.SizeBytes():] + l.Target.MarshalUnsafe(dst) + dst = dst[l.Target.SizeBytes():] + l.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (l *LinkAtReq) UnmarshalBytes(src []byte) { + l.DirFD.UnmarshalUnsafe(src) + src = src[l.DirFD.SizeBytes():] + l.Target.UnmarshalUnsafe(src) + src = src[l.Target.SizeBytes():] + l.Name.UnmarshalBytes(src) +} + +// LinkAtResp is used to respond to a successful LinkAt request. +// +// +marshal +type LinkAtResp struct { + Link Inode +} + +// FStatFSReq is used to request StatFS results for the specified FD. +// +// +marshal +type FStatFSReq struct { + FD FDID +} + +// StatFS is responded to a successful FStatFS request. +// +// +marshal +type StatFS struct { + Type uint64 + BlockSize int64 + Blocks uint64 + BlocksFree uint64 + BlocksAvailable uint64 + Files uint64 + FilesFree uint64 + NameLength uint64 +} + +// FAllocateReq is used to request to fallocate(2) an FD. This has no response. +// +// +marshal +type FAllocateReq struct { + FD FDID + _ uint32 + Mode uint64 + Offset uint64 + Length uint64 +} + +// ReadLinkAtReq is used to readlinkat(2) at the specified FD. +// +// +marshal +type ReadLinkAtReq struct { + FD FDID +} + +// ReadLinkAtResp is used to communicate ReadLinkAt results. +type ReadLinkAtResp struct { + Target SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *ReadLinkAtResp) SizeBytes() int { + return r.Target.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *ReadLinkAtResp) MarshalBytes(dst []byte) { + r.Target.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) { + r.Target.UnmarshalBytes(src) +} + +// FlushReq is used to make Flush requests. +// +// +marshal +type FlushReq struct { + FD FDID +} + +// ConnectReq is used to make a Connect request. +// +// +marshal +type ConnectReq struct { + FD FDID + // SockType is used to specify the socket type to connect to. As a special + // case, SockType = 0 means that the socket type does not matter and the + // requester will accept any socket type. + SockType uint32 +} + +// UnlinkAtReq is used to make UnlinkAt request. +type UnlinkAtReq struct { + DirFD FDID + Name SizedString + Flags primitive.Uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *UnlinkAtReq) SizeBytes() int { + return u.DirFD.SizeBytes() + u.Name.SizeBytes() + u.Flags.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *UnlinkAtReq) MarshalBytes(dst []byte) { + u.DirFD.MarshalUnsafe(dst) + dst = dst[u.DirFD.SizeBytes():] + u.Name.MarshalBytes(dst) + dst = dst[u.Name.SizeBytes():] + u.Flags.MarshalUnsafe(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *UnlinkAtReq) UnmarshalBytes(src []byte) { + u.DirFD.UnmarshalUnsafe(src) + src = src[u.DirFD.SizeBytes():] + u.Name.UnmarshalBytes(src) + src = src[u.Name.SizeBytes():] + u.Flags.UnmarshalUnsafe(src) +} + +// RenameAtReq is used to make Rename requests. Note that the request takes in +// the to-be-renamed file's FD instead of oldDir and oldName like renameat(2). +type RenameAtReq struct { + Renamed FDID + NewDir FDID + NewName SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *RenameAtReq) SizeBytes() int { + return r.Renamed.SizeBytes() + r.NewDir.SizeBytes() + r.NewName.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *RenameAtReq) MarshalBytes(dst []byte) { + r.Renamed.MarshalUnsafe(dst) + dst = dst[r.Renamed.SizeBytes():] + r.NewDir.MarshalUnsafe(dst) + dst = dst[r.NewDir.SizeBytes():] + r.NewName.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *RenameAtReq) UnmarshalBytes(src []byte) { + r.Renamed.UnmarshalUnsafe(src) + src = src[r.Renamed.SizeBytes():] + r.NewDir.UnmarshalUnsafe(src) + src = src[r.NewDir.SizeBytes():] + r.NewName.UnmarshalBytes(src) +} + +// Getdents64Req is used to make Getdents64 requests. +// +// +marshal +type Getdents64Req struct { + DirFD FDID + // Count is the number of bytes to read. A negative value of Count is used to + // indicate that the implementation must lseek(0, SEEK_SET) before calling + // getdents64(2). Implementations must use the absolute value of Count to + // determine the number of bytes to read. + Count int32 +} + +// Dirent64 is analogous to struct linux_dirent64. +type Dirent64 struct { + Ino primitive.Uint64 + DevMinor primitive.Uint32 + DevMajor primitive.Uint32 + Off primitive.Uint64 + Type primitive.Uint8 + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (d *Dirent64) SizeBytes() int { + return d.Ino.SizeBytes() + d.DevMinor.SizeBytes() + d.DevMajor.SizeBytes() + d.Off.SizeBytes() + d.Type.SizeBytes() + d.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (d *Dirent64) MarshalBytes(dst []byte) { + d.Ino.MarshalUnsafe(dst) + dst = dst[d.Ino.SizeBytes():] + d.DevMinor.MarshalUnsafe(dst) + dst = dst[d.DevMinor.SizeBytes():] + d.DevMajor.MarshalUnsafe(dst) + dst = dst[d.DevMajor.SizeBytes():] + d.Off.MarshalUnsafe(dst) + dst = dst[d.Off.SizeBytes():] + d.Type.MarshalUnsafe(dst) + dst = dst[d.Type.SizeBytes():] + d.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (d *Dirent64) UnmarshalBytes(src []byte) { + d.Ino.UnmarshalUnsafe(src) + src = src[d.Ino.SizeBytes():] + d.DevMinor.UnmarshalUnsafe(src) + src = src[d.DevMinor.SizeBytes():] + d.DevMajor.UnmarshalUnsafe(src) + src = src[d.DevMajor.SizeBytes():] + d.Off.UnmarshalUnsafe(src) + src = src[d.Off.SizeBytes():] + d.Type.UnmarshalUnsafe(src) + src = src[d.Type.SizeBytes():] + d.Name.UnmarshalBytes(src) +} + +// Getdents64Resp is used to communicate getdents64 results. +type Getdents64Resp struct { + Dirents []Dirent64 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *Getdents64Resp) SizeBytes() int { + ret := (*primitive.Uint32)(nil).SizeBytes() + for i := range g.Dirents { + ret += g.Dirents[i].SizeBytes() + } + return ret +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *Getdents64Resp) MarshalBytes(dst []byte) { + numDirents := primitive.Uint32(len(g.Dirents)) + numDirents.MarshalUnsafe(dst) + dst = dst[numDirents.SizeBytes():] + for i := range g.Dirents { + g.Dirents[i].MarshalBytes(dst) + dst = dst[g.Dirents[i].SizeBytes():] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *Getdents64Resp) UnmarshalBytes(src []byte) { + var numDirents primitive.Uint32 + numDirents.UnmarshalUnsafe(src) + if cap(g.Dirents) < int(numDirents) { + g.Dirents = make([]Dirent64, numDirents) + } else { + g.Dirents = g.Dirents[:numDirents] + } + + src = src[numDirents.SizeBytes():] + for i := range g.Dirents { + g.Dirents[i].UnmarshalBytes(src) + src = src[g.Dirents[i].SizeBytes():] + } +} + +// FGetXattrReq is used to make FGetXattr requests. The response to this is +// just a SizedString containing the xattr value. +type FGetXattrReq struct { + FD FDID + BufSize primitive.Uint32 + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *FGetXattrReq) SizeBytes() int { + return g.FD.SizeBytes() + g.BufSize.SizeBytes() + g.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *FGetXattrReq) MarshalBytes(dst []byte) { + g.FD.MarshalUnsafe(dst) + dst = dst[g.FD.SizeBytes():] + g.BufSize.MarshalUnsafe(dst) + dst = dst[g.BufSize.SizeBytes():] + g.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *FGetXattrReq) UnmarshalBytes(src []byte) { + g.FD.UnmarshalUnsafe(src) + src = src[g.FD.SizeBytes():] + g.BufSize.UnmarshalUnsafe(src) + src = src[g.BufSize.SizeBytes():] + g.Name.UnmarshalBytes(src) +} + +// FGetXattrResp is used to respond to FGetXattr request. +type FGetXattrResp struct { + Value SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (g *FGetXattrResp) SizeBytes() int { + return g.Value.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (g *FGetXattrResp) MarshalBytes(dst []byte) { + g.Value.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (g *FGetXattrResp) UnmarshalBytes(src []byte) { + g.Value.UnmarshalBytes(src) +} + +// FSetXattrReq is used to make FSetXattr requests. It has no response. +type FSetXattrReq struct { + FD FDID + Flags primitive.Uint32 + Name SizedString + Value SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *FSetXattrReq) SizeBytes() int { + return s.FD.SizeBytes() + s.Flags.SizeBytes() + s.Name.SizeBytes() + s.Value.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *FSetXattrReq) MarshalBytes(dst []byte) { + s.FD.MarshalUnsafe(dst) + dst = dst[s.FD.SizeBytes():] + s.Flags.MarshalUnsafe(dst) + dst = dst[s.Flags.SizeBytes():] + s.Name.MarshalBytes(dst) + dst = dst[s.Name.SizeBytes():] + s.Value.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *FSetXattrReq) UnmarshalBytes(src []byte) { + s.FD.UnmarshalUnsafe(src) + src = src[s.FD.SizeBytes():] + s.Flags.UnmarshalUnsafe(src) + src = src[s.Flags.SizeBytes():] + s.Name.UnmarshalBytes(src) + src = src[s.Name.SizeBytes():] + s.Value.UnmarshalBytes(src) +} + +// FRemoveXattrReq is used to make FRemoveXattr requests. It has no response. +type FRemoveXattrReq struct { + FD FDID + Name SizedString +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *FRemoveXattrReq) SizeBytes() int { + return r.FD.SizeBytes() + r.Name.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *FRemoveXattrReq) MarshalBytes(dst []byte) { + r.FD.MarshalUnsafe(dst) + dst = dst[r.FD.SizeBytes():] + r.Name.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) { + r.FD.UnmarshalUnsafe(src) + src = src[r.FD.SizeBytes():] + r.Name.UnmarshalBytes(src) +} + +// FListXattrReq is used to make FListXattr requests. +// +// +marshal +type FListXattrReq struct { + FD FDID + _ uint32 + Size uint64 +} + +// FListXattrResp is used to respond to FListXattr requests. +type FListXattrResp struct { + Xattrs StringArray +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (l *FListXattrResp) SizeBytes() int { + return l.Xattrs.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (l *FListXattrResp) MarshalBytes(dst []byte) { + l.Xattrs.MarshalBytes(dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (l *FListXattrResp) UnmarshalBytes(src []byte) { + l.Xattrs.UnmarshalBytes(src) +} diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go new file mode 100644 index 000000000..3868dfa08 --- /dev/null +++ b/pkg/lisafs/sample_message.go @@ -0,0 +1,110 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "math/rand" + + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + +// MsgSimple is a sample packed struct which can be used to test message passing. +// +// +marshal slice:Msg1Slice +type MsgSimple struct { + A uint16 + B uint16 + C uint32 + D uint64 +} + +// Randomize randomizes the contents of m. +func (m *MsgSimple) Randomize() { + m.A = uint16(rand.Uint32()) + m.B = uint16(rand.Uint32()) + m.C = rand.Uint32() + m.D = rand.Uint64() +} + +// MsgDynamic is a sample dynamic struct which can be used to test message passing. +// +// +marshal dynamic +type MsgDynamic struct { + N primitive.Uint32 + Arr []MsgSimple +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MsgDynamic) SizeBytes() int { + return m.N.SizeBytes() + + (int(m.N) * (*MsgSimple)(nil).SizeBytes()) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MsgDynamic) MarshalBytes(dst []byte) { + m.N.MarshalUnsafe(dst) + dst = dst[m.N.SizeBytes():] + MarshalUnsafeMsg1Slice(m.Arr, dst) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MsgDynamic) UnmarshalBytes(src []byte) { + m.N.UnmarshalUnsafe(src) + src = src[m.N.SizeBytes():] + m.Arr = make([]MsgSimple, m.N) + UnmarshalUnsafeMsg1Slice(m.Arr, src) +} + +// Randomize randomizes the contents of m. +func (m *MsgDynamic) Randomize(arrLen int) { + m.N = primitive.Uint32(arrLen) + m.Arr = make([]MsgSimple, arrLen) + for i := 0; i < arrLen; i++ { + m.Arr[i].Randomize() + } +} + +// P9Version mimics p9.TVersion and p9.Rversion. +// +// +marshal dynamic +type P9Version struct { + MSize primitive.Uint32 + Version string +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (v *P9Version) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + (*primitive.Uint16)(nil).SizeBytes() + len(v.Version) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (v *P9Version) MarshalBytes(dst []byte) { + v.MSize.MarshalUnsafe(dst) + dst = dst[v.MSize.SizeBytes():] + versionLen := primitive.Uint16(len(v.Version)) + versionLen.MarshalUnsafe(dst) + dst = dst[versionLen.SizeBytes():] + copy(dst, v.Version) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (v *P9Version) UnmarshalBytes(src []byte) { + v.MSize.UnmarshalUnsafe(src) + src = src[v.MSize.SizeBytes():] + var versionLen primitive.Uint16 + versionLen.UnmarshalUnsafe(src) + src = src[versionLen.SizeBytes():] + v.Version = string(src[:versionLen]) +} diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go new file mode 100644 index 000000000..7515355ec --- /dev/null +++ b/pkg/lisafs/server.go @@ -0,0 +1,113 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "gvisor.dev/gvisor/pkg/sync" +) + +// Server serves a filesystem tree. Multiple connections on different mount +// points can be started on a server. The server provides utilities to safely +// modify the filesystem tree across its connections (mount points). Note that +// it does not support synchronizing filesystem tree mutations across other +// servers serving the same filesystem subtree. Server also manages the +// lifecycle of all connections. +type Server struct { + // connWg counts the number of active connections being tracked. + connWg sync.WaitGroup + + // RenameMu synchronizes rename operations within this filesystem tree. + RenameMu sync.RWMutex + + // handlers is a list of RPC handlers which can be indexed by the handler's + // corresponding MID. + handlers []RPCHandler + + // mountPoints keeps track of all the mount points this server serves. + mpMu sync.RWMutex + mountPoints []*ControlFD + + // impl is the server implementation which embeds this server. + impl ServerImpl +} + +// Init must be called before first use of server. +func (s *Server) Init(impl ServerImpl) { + s.impl = impl + s.handlers = handlers[:] +} + +// InitTestOnly is the same as Init except that it allows to swap out the +// underlying handlers with something custom. This is for test only. +func (s *Server) InitTestOnly(impl ServerImpl, handlers []RPCHandler) { + s.impl = impl + s.handlers = handlers +} + +// WithRenameReadLock invokes fn with the server's rename mutex locked for +// reading. This ensures that no rename operations occur concurrently. +func (s *Server) WithRenameReadLock(fn func() error) error { + s.RenameMu.RLock() + err := fn() + s.RenameMu.RUnlock() + return err +} + +// StartConnection starts the connection on a separate goroutine and tracks it. +func (s *Server) StartConnection(c *Connection) { + s.connWg.Add(1) + go func() { + c.Run() + s.connWg.Done() + }() +} + +// Wait waits for all connections started via StartConnection() to terminate. +func (s *Server) Wait() { + s.connWg.Wait() +} + +func (s *Server) addMountPoint(root *ControlFD) { + s.mpMu.Lock() + defer s.mpMu.Unlock() + s.mountPoints = append(s.mountPoints, root) +} + +func (s *Server) forEachMountPoint(fn func(root *ControlFD)) { + s.mpMu.RLock() + defer s.mpMu.RUnlock() + for _, mp := range s.mountPoints { + fn(mp) + } +} + +// ServerImpl contains the implementation details for a Server. +// Implementations of ServerImpl should contain their associated Server by +// value as their first field. +type ServerImpl interface { + // Mount is called when a Mount RPC is made. It mounts the connection at + // mountPath. + // + // Precondition: mountPath == path.Clean(mountPath). + Mount(c *Connection, mountPath string) (ControlFDImpl, Inode, error) + + // SupportedMessages returns a list of messages that the server + // implementation supports. + SupportedMessages() []MID + + // MaxMessageSize is the maximum payload length (in bytes) that can be sent + // to this server implementation. + MaxMessageSize() uint32 +} diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go new file mode 100644 index 000000000..88210242f --- /dev/null +++ b/pkg/lisafs/sock.go @@ -0,0 +1,208 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/unet" +) + +var ( + sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes()) +) + +// sockHeader is the header present in front of each message received on a UDS. +// +// +marshal +type sockHeader struct { + payloadLen uint32 + message MID + _ uint16 // Need to make struct packed. +} + +// sockCommunicator implements Communicator. This is not thread safe. +type sockCommunicator struct { + fdTracker + sock *unet.Socket + buf []byte +} + +var _ Communicator = (*sockCommunicator)(nil) + +func newSockComm(sock *unet.Socket) *sockCommunicator { + return &sockCommunicator{ + sock: sock, + buf: make([]byte, sockHeaderLen), + } +} + +func (s *sockCommunicator) FD() int { + return s.sock.FD() +} + +func (s *sockCommunicator) destroy() { + s.sock.Close() +} + +func (s *sockCommunicator) shutdown() { + if err := s.sock.Shutdown(); err != nil { + log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err) + } +} + +func (s *sockCommunicator) resizeBuf(size uint32) { + if cap(s.buf) < int(size) { + s.buf = s.buf[:cap(s.buf)] + s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...) + } else { + s.buf = s.buf[:size] + } +} + +// PayloadBuf implements Communicator.PayloadBuf. +func (s *sockCommunicator) PayloadBuf(size uint32) []byte { + s.resizeBuf(sockHeaderLen + size) + return s.buf[sockHeaderLen : sockHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil { + return 0, 0, err + } + + return s.rcvMsg(wantFDs) +} + +// sndPrepopulatedMsg assumes that s.buf has already been populated with +// `payloadLen` bytes of data. +func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error { + header := sockHeader{payloadLen: payloadLen, message: m} + header.MarshalUnsafe(s.buf) + dataLen := sockHeaderLen + payloadLen + return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds) +} + +// writeTo writes the passed iovec to the UDS and donates any passed FDs. +func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error { + w := sock.Writer(true) + if len(fds) > 0 { + w.PackFDs(fds...) + } + + fdsUnpacked := false + for n := 0; n < dataLen; { + cur, err := w.WriteVec(iovec) + if err != nil { + return err + } + n += cur + + // Fast common path. + if n >= dataLen { + break + } + + // Consume iovecs. + for consumed := 0; consumed < cur; { + if len(iovec[0]) <= cur-consumed { + consumed += len(iovec[0]) + iovec = iovec[1:] + } else { + iovec[0] = iovec[0][cur-consumed:] + break + } + } + + if n > 0 && !fdsUnpacked { + // Don't resend any control message. + fdsUnpacked = true + w.UnpackFDs() + } + } + return nil +} + +// rcvMsg reads the message header and payload from the UDS. It also populates +// fds with any donated FDs. +func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) { + fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs) + if err != nil { + return 0, 0, err + } + for _, fd := range fds { + s.TrackFD(fd) + } + + var header sockHeader + header.UnmarshalUnsafe(s.buf) + + // No payload? We are done. + if header.payloadLen == 0 { + return header.message, 0, nil + } + + if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil { + return 0, 0, err + } + + return header.message, header.payloadLen, nil +} + +// readFrom fills the passed buffer with data from the socket. It also returns +// any donated FDs. +func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) { + r := sock.Reader(true) + r.EnableFDs(int(wantFDs)) + + var ( + fds []int + fdInit bool + ) + n := len(buf) + for got := 0; got < n; { + cur, err := r.ReadVec([][]byte{buf[got:]}) + + // Ignore EOF if cur > 0. + if err != nil && (err != io.EOF || cur == 0) { + r.CloseFDs() + return nil, err + } + + if !fdInit && cur > 0 { + fds, err = r.ExtractFDs() + if err != nil { + return nil, err + } + + fdInit = true + r.EnableFDs(0) + } + + got += cur + } + return fds, nil +} + +func closeFDs(fds []int) { + for _, fd := range fds { + if fd >= 0 { + unix.Close(fd) + } + } +} diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go new file mode 100644 index 000000000..387f4b7a8 --- /dev/null +++ b/pkg/lisafs/sock_test.go @@ -0,0 +1,217 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) { + sock1, sock2, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer sock1.Close() + defer sock2.Close() + + var testWg sync.WaitGroup + testWg.Add(2) + + go func() { + fun1(newSockComm(sock1)) + testWg.Done() + }() + + go func() { + fun2(newSockComm(sock2)) + testWg.Done() + }() + + testWg.Wait() +} + +func TestReadWrite(t *testing.T) { + // Create random data to send. + n := 10000 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + runSocketTest(t, func(comm *sockCommunicator) { + // Scatter that data into two parts using Iovecs while sending. + mid := n / 2 + if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + if _, err := readFrom(comm.sock, gotData, 0); err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + + // Make sure we got the right data. + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + }) +} + +func TestFDDonation(t *testing.T) { + n := 10 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + // Try donating FDs to these files. + path1 := "/dev/null" + path2 := "/dev" + path3 := "/dev/random" + + runSocketTest(t, func(comm *sockCommunicator) { + devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0) + defer unix.Close(devNullFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path1, err) + } + devFD, err := unix.Open(path2, unix.O_RDONLY, 0) + defer unix.Close(devFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0) + defer unix.Close(devRandomFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + fds, err := readFrom(comm.sock, gotData, 3) + if err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + defer closeFDs(fds[:]) + + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + + if len(fds) != 3 { + t.Fatalf("wanted 3 FD, got %d", len(fds)) + } + + // Check that the FDs actually point to the correct file. + compareFDWithFile(t, fds[0], path1) + compareFDWithFile(t, fds[1], path2) + compareFDWithFile(t, fds[2], path3) + }) +} + +func compareFDWithFile(t *testing.T, fd int, path string) { + var want unix.Stat_t + if err := unix.Stat(path, &want); err != nil { + t.Fatalf("stat(%s) failed: %v", path, err) + } + + var got unix.Stat_t + if err := unix.Fstat(fd, &got); err != nil { + t.Fatalf("fstat on donated FD failed: %v", err) + } + + if got.Ino != want.Ino || got.Dev != want.Dev { + t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got) + } +} + +func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error { + var payloadLen uint32 + if msg != nil { + payloadLen = uint32(msg.SizeBytes()) + msg.MarshalUnsafe(comm.PayloadBuf(payloadLen)) + } + return comm.sndPrepopulatedMsg(m, payloadLen, nil) +} + +func TestSndRcvMessage(t *testing.T) { + req := &MsgSimple{} + req.Randomize() + reqM := MID(1) + + // Create a massive random response. + var resp MsgDynamic + resp.Randomize(100) + respM := MID(2) + + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, req); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, &resp) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, req) + if err := testSndMsg(comm, respM, &resp); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func TestSndRcvMessageNoPayload(t *testing.T) { + reqM := MID(1) + respM := MID(2) + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, nil) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, nil) + if err := testSndMsg(comm, respM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) { + gotM, payloadLen, err := comm.rcvMsg(0) + if err != nil { + t.Fatalf("readMessageFrom failed: %v", err) + } + if gotM != wantM { + t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM) + } + if wantMsg == nil { + if payloadLen != 0 { + t.Errorf("no payload expect but got %d bytes", payloadLen) + } + } else { + gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable) + gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + if !reflect.DeepEqual(wantMsg, gotMsg) { + t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg) + } + } +} diff --git a/pkg/lisafs/testsuite/BUILD b/pkg/lisafs/testsuite/BUILD new file mode 100644 index 000000000..b4a542b3a --- /dev/null +++ b/pkg/lisafs/testsuite/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_library( + name = "testsuite", + testonly = True, + srcs = ["testsuite.go"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/lisafs", + "//pkg/unet", + "@com_github_syndtr_gocapability//capability:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go new file mode 100644 index 000000000..5fc7c364d --- /dev/null +++ b/pkg/lisafs/testsuite/testsuite.go @@ -0,0 +1,637 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testsuite provides a integration testing suite for lisafs. +// These tests are intended for servers serving the local filesystem. +package testsuite + +import ( + "bytes" + "fmt" + "io/ioutil" + "math/rand" + "os" + "testing" + "time" + + "github.com/syndtr/gocapability/capability" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/lisafs" + "gvisor.dev/gvisor/pkg/unet" +) + +// Tester is the client code using this test suite. This interface abstracts +// away all the caller specific details. +type Tester interface { + // NewServer returns a new instance of the tester server. + NewServer(t *testing.T) *lisafs.Server + + // LinkSupported returns true if the backing server supports LinkAt. + LinkSupported() bool + + // SetUserGroupIDSupported returns true if the backing server supports + // changing UID/GID for files. + SetUserGroupIDSupported() bool +} + +// RunAllLocalFSTests runs all local FS tests as subtests. +func RunAllLocalFSTests(t *testing.T, tester Tester) { + for name, testFn := range localFSTests { + t.Run(name, func(t *testing.T) { + runServerClient(t, tester, testFn) + }) + } +} + +type testFunc func(context.Context, *testing.T, Tester, lisafs.ClientFD) + +var localFSTests map[string]testFunc = map[string]testFunc{ + "Stat": testStat, + "RegularFileIO": testRegularFileIO, + "RegularFileOpen": testRegularFileOpen, + "SetStat": testSetStat, + "Allocate": testAllocate, + "StatFS": testStatFS, + "Unlink": testUnlink, + "Symlink": testSymlink, + "HardLink": testHardLink, + "Walk": testWalk, + "Rename": testRename, + "Mknod": testMknod, + "Getdents": testGetdents, +} + +func runServerClient(t *testing.T, tester Tester, testFn testFunc) { + mountPath, err := ioutil.TempDir(os.Getenv("TEST_TMPDIR"), "") + if err != nil { + t.Fatalf("creation of temporary mountpoint failed: %v", err) + } + defer os.RemoveAll(mountPath) + + // fsgofer should run with a umask of 0, because we want to preserve file + // modes exactly for testing purposes. + unix.Umask(0) + + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + + server := tester.NewServer(t) + conn, err := server.CreateConnection(serverSocket, false /* readonly */) + if err != nil { + t.Fatalf("starting connection failed: %v", err) + return + } + server.StartConnection(conn) + + c, root, err := lisafs.NewClient(clientSocket, mountPath) + if err != nil { + t.Fatalf("client creation failed: %v", err) + } + + if !root.ControlFD.Ok() { + t.Fatalf("root control FD is not valid") + } + rootFile := c.NewFD(root.ControlFD) + + ctx := context.Background() + testFn(ctx, t, tester, rootFile) + closeFD(ctx, t, rootFile) + + c.Close() // This should trigger client and server shutdown. + server.Wait() +} + +func closeFD(ctx context.Context, t testing.TB, fdLisa lisafs.ClientFD) { + if err := fdLisa.Close(ctx); err != nil { + t.Errorf("failed to close FD: %v", err) + } +} + +func statTo(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, stat *linux.Statx) { + if err := fdLisa.StatTo(ctx, stat); err != nil { + t.Fatalf("stat failed: %v", err) + } +} + +func openCreateFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx, lisafs.ClientFD, int) { + child, childFD, childHostFD, err := fdLisa.OpenCreateAt(ctx, name, unix.O_RDWR, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("OpenCreateAt failed: %v", err) + } + if childHostFD == -1 { + t.Error("no host FD donated") + } + client := fdLisa.Client() + return client.NewFD(child.ControlFD), child.Stat, fdLisa.Client().NewFD(childFD), childHostFD +} + +func openFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, flags uint32, isReg bool) (lisafs.ClientFD, int) { + newFD, hostFD, err := fdLisa.OpenAt(ctx, flags) + if err != nil { + t.Fatalf("OpenAt failed: %v", err) + } + if hostFD == -1 && isReg { + t.Error("no host FD donated") + } + return fdLisa.Client().NewFD(newFD), hostFD +} + +func unlinkFile(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, isDir bool) { + var flags uint32 + if isDir { + flags = unix.AT_REMOVEDIR + } + if err := dir.UnlinkAt(ctx, name, flags); err != nil { + t.Errorf("unlinking file %s failed: %v", name, err) + } +} + +func symlink(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name, target string) (lisafs.ClientFD, linux.Statx) { + linkIno, err := dir.SymlinkAt(ctx, name, target, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("symlink failed: %v", err) + } + return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat +} + +func link(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, target lisafs.ClientFD) (lisafs.ClientFD, linux.Statx) { + linkIno, err := dir.LinkAt(ctx, target.ID(), name) + if err != nil { + t.Fatalf("link failed: %v", err) + } + return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat +} + +func mkdir(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) { + childIno, err := dir.MkdirAt(ctx, name, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid())) + if err != nil { + t.Fatalf("mkdir failed: %v", err) + } + return dir.Client().NewFD(childIno.ControlFD), childIno.Stat +} + +func mknod(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) { + nodeIno, err := dir.MknodAt(ctx, name, unix.S_IFREG|0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()), 0, 0) + if err != nil { + t.Fatalf("mknod failed: %v", err) + } + return dir.Client().NewFD(nodeIno.ControlFD), nodeIno.Stat +} + +func walk(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []lisafs.Inode { + _, inodes, err := dir.WalkMultiple(ctx, names) + if err != nil { + t.Fatalf("walk failed while trying to walk components %+v: %v", names, err) + } + return inodes +} + +func walkStat(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []linux.Statx { + stats, err := dir.WalkStat(ctx, names) + if err != nil { + t.Fatalf("walk failed while trying to walk components %+v: %v", names, err) + } + return stats +} + +func writeFD(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, buf []byte) error { + count, err := fdLisa.Write(ctx, buf, off) + if err != nil { + return err + } + if int(count) != len(buf) { + t.Errorf("partial write: buf size = %d, written = %d", len(buf), count) + } + return nil +} + +func readFDAndCmp(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, want []byte) { + buf := make([]byte, len(want)) + n, err := fdLisa.Read(ctx, buf, off) + if err != nil { + t.Errorf("read failed: %v", err) + return + } + if int(n) != len(want) { + t.Errorf("partial read: buf size = %d, read = %d", len(want), n) + return + } + if bytes.Compare(buf, want) != 0 { + t.Errorf("bytes read differ from what was expected: want = %v, got = %v", want, buf) + } +} + +func allocateAndVerify(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, length uint64) { + if err := fdLisa.Allocate(ctx, 0, off, length); err != nil { + t.Fatalf("fallocate failed: %v", err) + } + + var stat linux.Statx + statTo(ctx, t, fdLisa, &stat) + if want := off + length; stat.Size != want { + t.Errorf("incorrect file size after allocate: expected %d, got %d", off+length, stat.Size) + } +} + +func cmpStatx(t *testing.T, want, got linux.Statx) { + if got.Mask&unix.STATX_MODE != 0 && want.Mask&unix.STATX_MODE != 0 { + if got.Mode != want.Mode { + t.Errorf("mode differs: want %d, got %d", want.Mode, got.Mode) + } + } + if got.Mask&unix.STATX_INO != 0 && want.Mask&unix.STATX_INO != 0 { + if got.Ino != want.Ino { + t.Errorf("inode number differs: want %d, got %d", want.Ino, got.Ino) + } + } + if got.Mask&unix.STATX_NLINK != 0 && want.Mask&unix.STATX_NLINK != 0 { + if got.Nlink != want.Nlink { + t.Errorf("nlink differs: want %d, got %d", want.Nlink, got.Nlink) + } + } + if got.Mask&unix.STATX_UID != 0 && want.Mask&unix.STATX_UID != 0 { + if got.UID != want.UID { + t.Errorf("UID differs: want %d, got %d", want.UID, got.UID) + } + } + if got.Mask&unix.STATX_GID != 0 && want.Mask&unix.STATX_GID != 0 { + if got.GID != want.GID { + t.Errorf("GID differs: want %d, got %d", want.GID, got.GID) + } + } + if got.Mask&unix.STATX_SIZE != 0 && want.Mask&unix.STATX_SIZE != 0 { + if got.Size != want.Size { + t.Errorf("size differs: want %d, got %d", want.Size, got.Size) + } + } + if got.Mask&unix.STATX_BLOCKS != 0 && want.Mask&unix.STATX_BLOCKS != 0 { + if got.Blocks != want.Blocks { + t.Errorf("blocks differs: want %d, got %d", want.Blocks, got.Blocks) + } + } + if got.Mask&unix.STATX_ATIME != 0 && want.Mask&unix.STATX_ATIME != 0 { + if got.Atime != want.Atime { + t.Errorf("atime differs: want %d, got %d", want.Atime, got.Atime) + } + } + if got.Mask&unix.STATX_MTIME != 0 && want.Mask&unix.STATX_MTIME != 0 { + if got.Mtime != want.Mtime { + t.Errorf("mtime differs: want %d, got %d", want.Mtime, got.Mtime) + } + } + if got.Mask&unix.STATX_CTIME != 0 && want.Mask&unix.STATX_CTIME != 0 { + if got.Ctime != want.Ctime { + t.Errorf("ctime differs: want %d, got %d", want.Ctime, got.Ctime) + } + } +} + +func hasCapability(c capability.Cap) bool { + caps, err := capability.NewPid2(os.Getpid()) + if err != nil { + return false + } + if err := caps.Load(); err != nil { + return false + } + return caps.Get(capability.EFFECTIVE, c) +} + +func testStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + var rootStat linux.Statx + if err := root.StatTo(ctx, &rootStat); err != nil { + t.Errorf("stat on the root dir failed: %v", err) + } + + if ftype := rootStat.Mode & unix.S_IFMT; ftype != unix.S_IFDIR { + t.Errorf("root inode is not a directory, file type = %d", ftype) + } +} + +func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + // Test Read/Write RPCs with 2MB of data to test IO in chunks. + data := make([]byte, 1<<21) + rand.Read(data) + if err := writeFD(ctx, t, fd, 0, data); err != nil { + t.Fatalf("write failed: %v", err) + } + readFDAndCmp(ctx, t, fd, 0, data) + readFDAndCmp(ctx, t, fd, 50, data[50:]) + + // Make sure the host FD is configured properly. + hostReadData := make([]byte, len(data)) + if n, err := unix.Pread(hostFD, hostReadData, 0); err != nil { + t.Errorf("host read failed: %v", err) + } else if n != len(hostReadData) { + t.Errorf("partial read: buf size = %d, read = %d", len(hostReadData), n) + } else if bytes.Compare(hostReadData, data) != 0 { + t.Errorf("bytes read differ from what was expected: want = %v, got = %v", data, hostReadData) + } + + // Test syncing the writable FD. + if err := fd.Sync(ctx); err != nil { + t.Errorf("syncing the FD failed: %v", err) + } +} + +func testRegularFileOpen(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + // Open a readonly FD and try writing to it to get an EBADF. + roFile, roHostFD := openFile(ctx, t, controlFile, unix.O_RDONLY, true /* isReg */) + defer closeFD(ctx, t, roFile) + defer unix.Close(roHostFD) + if err := writeFD(ctx, t, roFile, 0, []byte{1, 2, 3}); err != unix.EBADF { + t.Errorf("writing to read only FD should generate EBADF, but got %v", err) + } +} + +func testSetStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + now := time.Now() + wantStat := linux.Statx{ + Mask: unix.STATX_MODE | unix.STATX_ATIME | unix.STATX_MTIME | unix.STATX_SIZE, + Mode: 0760, + UID: uint32(unix.Getuid()), + GID: uint32(unix.Getgid()), + Size: 50, + Atime: linux.NsecToStatxTimestamp(now.UnixNano()), + Mtime: linux.NsecToStatxTimestamp(now.UnixNano()), + } + if tester.SetUserGroupIDSupported() { + wantStat.Mask |= unix.STATX_UID | unix.STATX_GID + } + failureMask, failureErr, err := controlFile.SetStat(ctx, &wantStat) + if err != nil { + t.Fatalf("setstat failed: %v", err) + } + if failureMask != 0 { + t.Fatalf("some setstat operations failed: failureMask = %#b, failureErr = %v", failureMask, failureErr) + } + + // Verify that attributes were updated. + var gotStat linux.Statx + statTo(ctx, t, controlFile, &gotStat) + if gotStat.Mode&07777 != wantStat.Mode || + gotStat.Size != wantStat.Size || + gotStat.Atime.ToNsec() != wantStat.Atime.ToNsec() || + gotStat.Mtime.ToNsec() != wantStat.Mtime.ToNsec() || + (tester.SetUserGroupIDSupported() && (uint32(gotStat.UID) != wantStat.UID || uint32(gotStat.GID) != wantStat.GID)) { + t.Errorf("setStat did not update file correctly: setStat = %+v, stat = %+v", wantStat, gotStat) + } +} + +func testAllocate(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + allocateAndVerify(ctx, t, fd, 0, 40) + allocateAndVerify(ctx, t, fd, 20, 100) +} + +func testStatFS(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + var statFS lisafs.StatFS + if err := root.StatFSTo(ctx, &statFS); err != nil { + t.Errorf("statfs failed: %v", err) + } +} + +func testUnlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + unlinkFile(ctx, t, root, name, false /* isDir */) + if inodes := walk(ctx, t, root, []string{name}); len(inodes) > 0 { + t.Errorf("deleted file should not be generating inodes on walk: %+v", inodes) + } +} + +func testSymlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + target := "/tmp/some/path" + name := "symlinkFile" + link, linkStat := symlink(ctx, t, root, name, target) + defer closeFD(ctx, t, link) + + if linkStat.Mode&unix.S_IFMT != unix.S_IFLNK { + t.Errorf("stat return from symlink RPC indicates that the inode is not a symlink: mode = %d", linkStat.Mode) + } + + if gotTarget, err := link.ReadLinkAt(ctx); err != nil { + t.Fatalf("readlink failed: %v", err) + } else if gotTarget != target { + t.Errorf("readlink return incorrect target: expected %q, got %q", target, gotTarget) + } +} + +func testHardLink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + if !tester.LinkSupported() { + t.Skipf("server does not support LinkAt RPC") + } + if !hasCapability(capability.CAP_DAC_READ_SEARCH) { + t.Skipf("TestHardLink requires CAP_DAC_READ_SEARCH, running as %d", unix.Getuid()) + } + name := "tempFile" + controlFile, fileIno, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, controlFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + link, linkStat := link(ctx, t, root, name, controlFile) + defer closeFD(ctx, t, link) + + if linkStat.Ino != fileIno.Ino { + t.Errorf("hard linked files have different inode numbers: %d %d", linkStat.Ino, fileIno.Ino) + } + if linkStat.DevMinor != fileIno.DevMinor { + t.Errorf("hard linked files have different minor device numbers: %d %d", linkStat.DevMinor, fileIno.DevMinor) + } + if linkStat.DevMajor != fileIno.DevMajor { + t.Errorf("hard linked files have different major device numbers: %d %d", linkStat.DevMajor, fileIno.DevMajor) + } +} + +func testWalk(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + // Create 10 nested directories. + n := 10 + curDir := root + + dirNames := make([]string, 0, n) + for i := 0; i < n; i++ { + name := fmt.Sprintf("tmpdir-%d", i) + childDir, _ := mkdir(ctx, t, curDir, name) + defer closeFD(ctx, t, childDir) + defer unlinkFile(ctx, t, curDir, name, true /* isDir */) + + curDir = childDir + dirNames = append(dirNames, name) + } + + // Walk all these directories. Add some junk at the end which should not be + // walked on. + dirNames = append(dirNames, []string{"a", "b", "c"}...) + inodes := walk(ctx, t, root, dirNames) + if len(inodes) != n { + t.Errorf("walk returned the incorrect number of inodes: wanted %d, got %d", n, len(inodes)) + } + + // Close all control FDs and collect stat results for all dirs including + // the root directory. + dirStats := make([]linux.Statx, 0, n+1) + var stat linux.Statx + statTo(ctx, t, root, &stat) + dirStats = append(dirStats, stat) + for _, inode := range inodes { + dirStats = append(dirStats, inode.Stat) + closeFD(ctx, t, root.Client().NewFD(inode.ControlFD)) + } + + // Test WalkStat which additonally returns Statx for root because the first + // path component is "". + dirNames = append([]string{""}, dirNames...) + gotStats := walkStat(ctx, t, root, dirNames) + if len(gotStats) != len(dirStats) { + t.Errorf("walkStat returned the incorrect number of statx: wanted %d, got %d", len(dirStats), len(gotStats)) + } else { + for i := range gotStats { + cmpStatx(t, dirStats[i], gotStats[i]) + } + } +} + +func testRename(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "tempFile" + tempFile, _, fd, hostFD := openCreateFile(ctx, t, root, name) + defer closeFD(ctx, t, tempFile) + defer closeFD(ctx, t, fd) + defer unix.Close(hostFD) + + tempDir, _ := mkdir(ctx, t, root, "tempDir") + defer closeFD(ctx, t, tempDir) + + // Move tempFile into tempDir. + if err := tempFile.RenameTo(ctx, tempDir.ID(), "movedFile"); err != nil { + t.Fatalf("rename failed: %v", err) + } + + inodes := walkStat(ctx, t, root, []string{"tempDir", "movedFile"}) + if len(inodes) != 2 { + t.Errorf("expected 2 files on walk but only found %d", len(inodes)) + } +} + +func testMknod(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + name := "namedPipe" + pipeFile, pipeStat := mknod(ctx, t, root, name) + defer closeFD(ctx, t, pipeFile) + + var stat linux.Statx + statTo(ctx, t, pipeFile, &stat) + + if stat.Mode != pipeStat.Mode { + t.Errorf("mknod mode is incorrect: want %d, got %d", pipeStat.Mode, stat.Mode) + } + if stat.UID != pipeStat.UID { + t.Errorf("mknod UID is incorrect: want %d, got %d", pipeStat.UID, stat.UID) + } + if stat.GID != pipeStat.GID { + t.Errorf("mknod GID is incorrect: want %d, got %d", pipeStat.GID, stat.GID) + } +} + +func testGetdents(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) { + tempDir, _ := mkdir(ctx, t, root, "tempDir") + defer closeFD(ctx, t, tempDir) + defer unlinkFile(ctx, t, root, "tempDir", true /* isDir */) + + // Create 10 files in tempDir. + n := 10 + fileStats := make(map[string]linux.Statx) + for i := 0; i < n; i++ { + name := fmt.Sprintf("file-%d", i) + newFile, fileStat := mknod(ctx, t, tempDir, name) + defer closeFD(ctx, t, newFile) + defer unlinkFile(ctx, t, tempDir, name, false /* isDir */) + + fileStats[name] = fileStat + } + + // Use opened directory FD for getdents. + openDirFile, _ := openFile(ctx, t, tempDir, unix.O_RDONLY, false /* isReg */) + defer closeFD(ctx, t, openDirFile) + + dirents := make([]lisafs.Dirent64, 0, n) + for i := 0; i < n+2; i++ { + gotDirents, err := openDirFile.Getdents64(ctx, 40) + if err != nil { + t.Fatalf("getdents failed: %v", err) + } + if len(gotDirents) == 0 { + break + } + for _, dirent := range gotDirents { + if dirent.Name != "." && dirent.Name != ".." { + dirents = append(dirents, dirent) + } + } + } + + if len(dirents) != n { + t.Errorf("got incorrect number of dirents: wanted %d, got %d", n, len(dirents)) + } + for _, dirent := range dirents { + stat, ok := fileStats[string(dirent.Name)] + if !ok { + t.Errorf("received a dirent that was not created: %+v", dirent) + continue + } + + if dirent.Type != unix.DT_REG { + t.Errorf("dirent type of %s is incorrect: %d", dirent.Name, dirent.Type) + } + if uint64(dirent.Ino) != stat.Ino { + t.Errorf("dirent ino of %s is incorrect: want %d, got %d", dirent.Name, stat.Ino, dirent.Ino) + } + if uint32(dirent.DevMinor) != stat.DevMinor { + t.Errorf("dirent dev minor of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMinor, dirent.DevMinor) + } + if uint32(dirent.DevMajor) != stat.DevMajor { + t.Errorf("dirent dev major of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMajor, dirent.DevMajor) + } + } +} diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 0b961d3d9..6358ad8e9 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -384,6 +384,14 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error { return descriptor.verify(params.Expected, params.HashAlgorithms) } +// cachedHashes stores verified hashes from a previous hash step. +type cachedHashes struct { + // offset is the offset of cached hash in each level. + offset []int64 + // hash is the verified cache for each level from previous hash steps. + hash [][]byte +} + // Verify verifies the content read from data with offset. The content is // verified against tree. If content spans across multiple blocks, each block is // verified. Verification fails if the hash of the data does not match the tree @@ -409,29 +417,32 @@ func Verify(params *VerifyParams) (int64, error) { firstDataBlock := params.ReadOffset / layout.blockSize lastDataBlock := (params.ReadOffset + params.ReadSize - 1) / layout.blockSize - buf := make([]byte, layout.blockSize) - var readErr error - total := int64(0) + size := (lastDataBlock - firstDataBlock + 1) * layout.blockSize + retBuf := make([]byte, size) + n, err := params.File.ReadAt(retBuf, firstDataBlock*layout.blockSize) + if err != nil && err != io.EOF { + return 0, err + } + total := int64(n) + bytesRead := int64(0) + + // Only cache hash results if reading more than a block. + var ch *cachedHashes + if lastDataBlock > firstDataBlock { + ch = &cachedHashes{ + offset: make([]int64, layout.numLevels()), + hash: make([][]byte, layout.numLevels()), + } + } for i := firstDataBlock; i <= lastDataBlock; i++ { + // Reach the end of file during verification. + if total <= 0 { + return bytesRead, io.EOF + } // Read a block that includes all or part of target range in // input data. - bytesRead, err := params.File.ReadAt(buf, i*layout.blockSize) - readErr = err - // If at the end of input data and all previous blocks are - // verified, return the verified input data and EOF. - if readErr == io.EOF && bytesRead == 0 { - break - } - if readErr != nil && readErr != io.EOF { - return 0, fmt.Errorf("read from data failed: %w", err) - } - // If this is the end of file, zero the remaining bytes in buf, - // otherwise they are still from the previous block. - if bytesRead < len(buf) { - for j := bytesRead; j < len(buf); j++ { - buf[j] = 0 - } - } + buf := retBuf[(i-firstDataBlock)*layout.blockSize : (i-firstDataBlock+1)*layout.blockSize] + descriptor := VerityDescriptor{ Name: params.Name, FileSize: params.Size, @@ -441,8 +452,8 @@ func Verify(params *VerifyParams) (int64, error) { SymlinkTarget: params.SymlinkTarget, Children: params.Children, } - if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil { - return 0, err + if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected, ch); err != nil { + return bytesRead, err } // startOff is the beginning of the read range within the @@ -459,22 +470,24 @@ func Verify(params *VerifyParams) (int64, error) { if i == lastDataBlock { endOff = (params.ReadOffset+params.ReadSize-1)%layout.blockSize + 1 } + // If the provided size exceeds the end of input data, we should // only copy the parts in buf that's part of input data. - if startOff > int64(bytesRead) { - startOff = int64(bytesRead) + if startOff > total { + startOff = total } - if endOff > int64(bytesRead) { - endOff = int64(bytesRead) + if endOff > total { + endOff = total } + n, err := params.Out.Write(buf[startOff:endOff]) if err != nil { - return total, err + return bytesRead, err } - total += int64(n) - + bytesRead += int64(n) + total -= endOff } - return total, readErr + return bytesRead, nil } // verifyBlock verifies a block against tree. index is the number of block in @@ -482,7 +495,7 @@ func Verify(params *VerifyParams) (int64, error) { // fails if the calculated hash from block is different from any level of // hashes stored in tree. And the final root hash is compared with // expected. -func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error { +func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte, ch *cachedHashes) error { if len(dataBlock) != int(layout.blockSize) { return fmt.Errorf("incorrect block size") } @@ -491,6 +504,12 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, treeBlock := make([]byte, layout.blockSize) var digest []byte for level := 0; level < layout.numLevels(); level++ { + // No need to verify remaining levels if the current block has + // been verified in a previous call and cached. + if ch != nil && ch.offset[level] == layout.digestOffset(level, blockIndex) && ch.hash[level] != nil { + break + } + // Calculate hash. if level == 0 { h, err := hashData(dataBlock, hashAlgorithms) @@ -521,11 +540,19 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, if !bytes.Equal(digest, expectedDigest) { return fmt.Errorf("verification failed") } + if ch != nil { + ch.offset[level] = layout.digestOffset(level, blockIndex) + ch.hash[level] = expectedDigest + } blockIndex = blockIndex / layout.hashesPerBlock() } // Verification for the tree succeeded. Now hash the descriptor with // the root hash and compare it with expected. - descriptor.RootHash = digest + if ch != nil { + descriptor.RootHash = ch.hash[layout.rootLevel()] + } else { + descriptor.RootHash = digest + } return descriptor.verify(expected, hashAlgorithms) } diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 764f1f970..d618da820 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -115,7 +115,7 @@ type Client struct { // channels is the set of all initialized channels. channels []*channel - // availableChannels is a FIFO of inactive channels. + // availableChannels is a LIFO of inactive channels. availableChannels []*channel // -- below corresponds to sendRecvLegacy -- @@ -528,7 +528,7 @@ func (c *Client) sendRecvChannel(t message, r message) error { } // Send the request and receive the server's response. - rsz, err := ch.send(t) + rsz, err := ch.send(t, false /* isServer */) if err != nil { // See above. c.channelsMu.Lock() diff --git a/pkg/p9/file.go b/pkg/p9/file.go index 97e0231d6..8d6af2d6b 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -325,6 +325,12 @@ func (*DisallowServerCalls) Renamed(File, string) { func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { stats := make([]FullStat, 0, len(names)) parent := start + closeParent := func() { + if parent != start { + _ = parent.Close() + } + } + defer closeParent() mask := AttrMaskAll() for i, name := range names { if len(name) == 0 && i == 0 { @@ -340,15 +346,14 @@ func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { continue } qids, child, valid, attr, err := parent.WalkGetAttr([]string{name}) - if parent != start { - _ = parent.Close() - } if err != nil { if errors.Is(err, unix.ENOENT) { return stats, nil } return nil, err } + closeParent() + parent = child stats = append(stats, FullStat{ QID: qids[0], Valid: valid, @@ -357,13 +362,8 @@ func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { if attr.Mode.FileType() != ModeDirectory { // Doesn't need to continue if entry is not a dir. Including symlinks // that cannot be followed. - _ = child.Close() break } - parent = child - } - if parent != start { - _ = parent.Close() } return stats, nil } diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 161b451cc..a8f8a9d03 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -45,6 +45,8 @@ func ExtractErrno(err error) unix.Errno { // Attempt to unwrap. switch e := err.(type) { + case *errors.Error: + return unix.Errno(e.Errno()) case unix.Errno: return e case *os.PathError: diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index 802254a90..69a9f2537 100644 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go @@ -85,7 +85,7 @@ func (ch *channel) service(cs *connState) error { } r := cs.handle(m) msgRegistry.put(m) - rsz, err = ch.send(r) + rsz, err = ch.send(r, true /* isServer */) if err != nil { return err } @@ -122,7 +122,7 @@ func (ch *channel) Close() error { // // The return value is the size of the received response. Not that in the // server case, this is the size of the next request. -func (ch *channel) send(m message) (uint32, error) { +func (ch *channel) send(m message, isServer bool) (uint32, error) { if log.IsLogging(log.Debug) { log.Debugf("send [channel @%p] %s", ch, m.String()) } @@ -162,7 +162,11 @@ func (ch *channel) send(m message) (uint32, error) { } // Perform the one-shot communication. - return ch.data.SendRecv(ssz) + if isServer { + return ch.data.SendRecv(ssz) + } + // RPCs are expected to return quickly rather than block. + return ch.data.SendRecvFast(ssz) } // recv decodes a message that exists on the channel. diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go index b6e2012e8..38ce9be1e 100644 --- a/pkg/ring0/defs.go +++ b/pkg/ring0/defs.go @@ -77,6 +77,9 @@ type CPU struct { // calls and exceptions via the Registers function. registers arch.Registers + // floatingPointState holds floating point state. + floatingPointState fpu.State + // hooks are kernel hooks. hooks Hooks } @@ -90,6 +93,15 @@ func (c *CPU) Registers() *arch.Registers { return &c.registers } +// FloatingPointState returns the kernel floating point state. +// +// This is explicitly safe to call during KernelException and KernelSyscall. +// +//go:nosplit +func (c *CPU) FloatingPointState() *fpu.State { + return &c.floatingPointState +} + // SwitchOpts are passed to the Switch function. type SwitchOpts struct { // Registers are the user register state. diff --git a/pkg/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go index 24f6e4cde..81e90dbf7 100644 --- a/pkg/ring0/defs_amd64.go +++ b/pkg/ring0/defs_amd64.go @@ -116,6 +116,11 @@ type CPUArchState struct { errorType uintptr *kernelEntry + + // Copies of global variables, stored in CPU so that they can be used by + // syscall and exception handlers (in the upper address space). + hasXSAVE bool + hasXSAVEOPT bool } // ErrorCode returns the last error code. diff --git a/pkg/ring0/entry_amd64.go b/pkg/ring0/entry_amd64.go index afd646b0b..13ad4e4df 100644 --- a/pkg/ring0/entry_amd64.go +++ b/pkg/ring0/entry_amd64.go @@ -39,11 +39,6 @@ func sysenter() // assembly to get the ABI0 (i.e., primary) address. func addrOfSysenter() uintptr -// swapgs swaps the current GS value. -// -// This must be called prior to sysret/iret. -func swapgs() - // jumpToKernel jumps to the kernel version of the current RIP. func jumpToKernel() diff --git a/pkg/ring0/entry_amd64.s b/pkg/ring0/entry_amd64.s index 520bd9f57..d2913f190 100644 --- a/pkg/ring0/entry_amd64.s +++ b/pkg/ring0/entry_amd64.s @@ -142,8 +142,103 @@ TEXT ·jumpToUser(SB),NOSPLIT,$0 MOVQ AX, 0(SP) RET +// See kernel_amd64.go. +// +// The 16-byte frame size is for the saved values of MXCSR and the x87 control +// word. +TEXT ·doSwitchToUser(SB),NOSPLIT,$16-48 + // We are passed pointers to heap objects, but do not store them in our + // local frame. + NO_LOCAL_POINTERS + + // MXCSR and the x87 control word are the only floating point state + // that is callee-save and thus we must save. + STMXCSR mxcsr-0(SP) + FSTCW cw-8(SP) + + // Restore application floating point state. + MOVQ cpu+0(FP), SI + MOVQ fpState+16(FP), DI + MOVB ·hasXSAVE(SB), BX + TESTB BX, BX + JZ no_xrstor + // Use xrstor to restore all available fp state. For now, we restore + // everything unconditionally by setting the implicit operand edx:eax + // (the "requested feature bitmap") to all 1's. + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f // XRSTOR64 0(DI) + JMP fprestore_done +no_xrstor: + // Fall back to fxrstor if xsave is not available. + FXRSTOR64 0(DI) +fprestore_done: + + // Set application GS. + MOVQ regs+8(FP), R8 + SWAP_GS() + MOVQ PTRACE_GS_BASE(R8), AX + PUSHQ AX + CALL ·writeGS(SB) + POPQ AX + + // Call sysret() or iret(). + MOVQ userCR3+24(FP), CX + MOVQ needIRET+32(FP), R9 + ADDQ $-32, SP + MOVQ SI, 0(SP) // cpu + MOVQ R8, 8(SP) // regs + MOVQ CX, 16(SP) // userCR3 + TESTQ R9, R9 + JNZ do_iret + CALL ·sysret(SB) + JMP done_sysret_or_iret +do_iret: + CALL ·iret(SB) +done_sysret_or_iret: + MOVQ 24(SP), AX // vector + ADDQ $32, SP + MOVQ AX, vector+40(FP) + + // Save application floating point state. + MOVQ fpState+16(FP), DI + MOVB ·hasXSAVE(SB), BX + MOVB ·hasXSAVEOPT(SB), CX + TESTB BX, BX + JZ no_xsave + // Use xsave/xsaveopt to save all extended state. + // We save everything unconditionally by setting RFBM to all 1's. + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + TESTB CX, CX + JZ no_xsaveopt + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI) + JMP fpsave_done +no_xsaveopt: + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI) + JMP fpsave_done +no_xsave: + FXSAVE64 0(DI) +fpsave_done: + + // Restore MXCSR and the x87 control word after one of the two floating + // point save cases above, to ensure the application versions are saved + // before being clobbered here. + LDMXCSR mxcsr-0(SP) + + // FLDCW is a "waiting" x87 instruction, meaning it checks for pending + // unmasked exceptions before executing. Thus if userspace has unmasked + // an exception and has one pending, it can be raised by FLDCW even + // though the new control word will mask exceptions. To prevent this, + // we must first clear pending exceptions (which will be restored by + // XRSTOR, et al). + BYTE $0xDB; BYTE $0xE2; // FNCLEX + FLDCW cw-8(SP) + + RET + // See entry_amd64.go. -TEXT ·sysret(SB),NOSPLIT,$0-24 +TEXT ·sysret(SB),NOSPLIT,$0-32 // Set application FS. We can't do this in Go because Go code needs FS. MOVQ regs+8(FP), AX MOVQ PTRACE_FS_BASE(AX), AX @@ -182,9 +277,11 @@ TEXT ·sysret(SB),NOSPLIT,$0-24 POPQ AX // Restore AX. POPQ SP // Restore SP. SYSRET64() + // sysenter or exception will write our return value and return to our + // caller. // See entry_amd64.go. -TEXT ·iret(SB),NOSPLIT,$0-24 +TEXT ·iret(SB),NOSPLIT,$0-32 // Set application FS. We can't do this in Go because Go code needs FS. MOVQ regs+8(FP), AX MOVQ PTRACE_FS_BASE(AX), AX @@ -220,6 +317,8 @@ TEXT ·iret(SB),NOSPLIT,$0-24 WRITE_CR3() // Switch to userCR3. POPQ AX // Restore AX. IRET() + // sysenter or exception will write our return value and return to our + // caller. // See entry_amd64.go. TEXT ·resume(SB),NOSPLIT,$0 @@ -324,11 +423,39 @@ kernel: MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. + // Save floating point state. CPU.floatingPointState is a slice, so the + // first word of CPU.floatingPointState is a pointer to the destination + // array. + MOVQ CPU_FPU_STATE(AX), DI + MOVB CPU_HAS_XSAVE(AX), BX + MOVB CPU_HAS_XSAVEOPT(AX), CX + TESTB BX, BX + JZ no_xsave + // Use xsave/xsaveopt to save all extended state. + // We save everything unconditionally by setting RFBM to all 1's. + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + TESTB CX, CX + JZ no_xsaveopt + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI) + JMP fpsave_done +no_xsaveopt: + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI) + JMP fpsave_done +no_xsave: + FXSAVE64 0(DI) +fpsave_done: + // Call the syscall trampoline. LOAD_KERNEL_STACK(GS) - PUSHQ AX // First argument (vCPU). - CALL ·kernelSyscall(SB) // Call the trampoline. - POPQ AX // Pop vCPU. + MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU. + PUSHQ AX // First argument (vCPU). + CALL ·kernelSyscall(SB) // Call the trampoline. + POPQ AX // Pop vCPU. + + // We only trigger a bluepill entry in the bluepill function, and can + // therefore be guaranteed that there is no floating point state to be + // loaded on resuming from halt. JMP ·resume(SB) ADDR_OF_FUNC(·addrOfSysenter(SB), ·sysenter(SB)); @@ -416,15 +543,43 @@ kernel: MOVQ 8(SP), BX // Load the error code. MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU. MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. - MOVQ 0(SP), BX // BX contains the vector. + + // Save floating point state. CPU.floatingPointState is a slice, so the + // first word of CPU.floatingPointState is a pointer to the destination + // array. + MOVQ CPU_FPU_STATE(AX), DI + MOVB CPU_HAS_XSAVE(AX), BX + MOVB CPU_HAS_XSAVEOPT(AX), CX + TESTB BX, BX + JZ no_xsave + // Use xsave/xsaveopt to save all extended state. + // We save everything unconditionally by setting RFBM to all 1's. + MOVL $0xffffffff, AX + MOVL $0xffffffff, DX + TESTB CX, CX + JZ no_xsaveopt + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI) + JMP fpsave_done +no_xsaveopt: + BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI) + JMP fpsave_done +no_xsave: + FXSAVE64 0(DI) +fpsave_done: // Call the exception trampoline. + MOVQ 0(SP), BX // BX contains the vector. LOAD_KERNEL_STACK(GS) - PUSHQ BX // Second argument (vector). - PUSHQ AX // First argument (vCPU). - CALL ·kernelException(SB) // Call the trampoline. - POPQ BX // Pop vector. - POPQ AX // Pop vCPU. + MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU. + PUSHQ BX // Second argument (vector). + PUSHQ AX // First argument (vCPU). + CALL ·kernelException(SB) // Call the trampoline. + POPQ BX // Pop vector. + POPQ AX // Pop vCPU. + + // We only trigger a bluepill entry in the bluepill function, and can + // therefore be guaranteed that there is no floating point state to be + // loaded on resuming from halt. JMP ·resume(SB) #define EXCEPTION_WITH_ERROR(value, symbol, addr) \ diff --git a/pkg/ring0/kernel.go b/pkg/ring0/kernel.go index 292f9d0cc..e7dd84929 100644 --- a/pkg/ring0/kernel.go +++ b/pkg/ring0/kernel.go @@ -14,6 +14,10 @@ package ring0 +import ( + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" +) + // Init initializes a new kernel. // //go:nosplit @@ -80,6 +84,7 @@ func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) { c.self = c // Set self reference. c.kernel = k // Set kernel reference. c.init(cpuID) // Perform architectural init. + c.floatingPointState = fpu.NewState() // Require hooks. if hooks != nil { diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index 4a4c0ae26..7e55011b5 100644 --- a/pkg/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go @@ -143,6 +143,9 @@ func (c *CPU) init(cpuID int) { // Set mandatory flags. c.registers.Eflags = KernelFlagsSet + + c.hasXSAVE = hasXSAVE + c.hasXSAVEOPT = hasXSAVEOPT } // StackTop returns the kernel's stack address. @@ -248,19 +251,21 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Ss = uint64(Udata) // Ditto. // Perform the switch. - swapgs() // GS will be swapped on return. - WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. - LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point. + needIRET := uint64(0) if switchOpts.FullRestore { - vector = iret(c, regs, uintptr(userCR3)) - } else { - vector = sysret(c, regs, uintptr(userCR3)) + needIRET = 1 } - SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point. - RestoreKernelFPState() // escapes: no. Restore kernel MXCSR. + vector = doSwitchToUser(c, regs, switchOpts.FloatingPointState.BytePointer(), userCR3, needIRET) // escapes: no. return } +func doSwitchToUser( + cpu *CPU, // +0(FP) + regs *arch.Registers, // +8(FP) + fpState *byte, // +16(FP) + userCR3 uint64, // +24(FP) + needIRET uint64) Vector // +32(FP), +40(FP) + var ( sentryXCR0 uintptr sentryXCR0Once sync.Once @@ -287,7 +292,7 @@ func initSentryXCR0() { //go:nosplit func startGo(c *CPU) { // Save per-cpu. - WriteGS(kernelAddr(c.kernelEntry)) + writeGS(kernelAddr(c.kernelEntry)) // // TODO(mpratt): Note that per the note above, this should be done diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go index 05c394ff5..c42a5b205 100644 --- a/pkg/ring0/lib_amd64.go +++ b/pkg/ring0/lib_amd64.go @@ -21,29 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" ) -// LoadFloatingPoint loads floating point state by the most efficient mechanism -// available (set by Init). -var LoadFloatingPoint func(*byte) - -// SaveFloatingPoint saves floating point state by the most efficient mechanism -// available (set by Init). -var SaveFloatingPoint func(*byte) - -// fxrstor uses fxrstor64 to load floating point state. -func fxrstor(*byte) - -// xrstor uses xrstor to load floating point state. -func xrstor(*byte) - -// fxsave uses fxsave64 to save floating point state. -func fxsave(*byte) - -// xsave uses xsave to save floating point state. -func xsave(*byte) - -// xsaveopt uses xsaveopt to save floating point state. -func xsaveopt(*byte) - // writeFS sets the FS base address (selects one of wrfsbase or wrfsmsr). func writeFS(addr uintptr) @@ -53,8 +30,8 @@ func wrfsbase(addr uintptr) // wrfsmsr writes to the GS_BASE MSR. func wrfsmsr(addr uintptr) -// WriteGS sets the GS address (set by init). -var WriteGS func(addr uintptr) +// writeGS sets the GS address (selects one of wrgsbase or wrgsmsr). +func writeGS(addr uintptr) // wrgsbase writes to the GS base address. func wrgsbase(addr uintptr) @@ -106,19 +83,4 @@ func Init(featureSet *cpuid.FeatureSet) { hasXSAVE = featureSet.UseXsave() hasFSGSBASE = featureSet.HasFeature(cpuid.X86FeatureFSGSBase) validXCR0Mask = uintptr(featureSet.ValidXCR0Mask()) - if hasXSAVEOPT { - SaveFloatingPoint = xsaveopt - LoadFloatingPoint = xrstor - } else if hasXSAVE { - SaveFloatingPoint = xsave - LoadFloatingPoint = xrstor - } else { - SaveFloatingPoint = fxsave - LoadFloatingPoint = fxrstor - } - if hasFSGSBASE { - WriteGS = wrgsbase - } else { - WriteGS = wrgsmsr - } } diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s index 8ed98fc84..0f283aaae 100644 --- a/pkg/ring0/lib_amd64.s +++ b/pkg/ring0/lib_amd64.s @@ -128,6 +128,29 @@ TEXT ·wrfsmsr(SB),NOSPLIT,$0-8 BYTE $0x0f; BYTE $0x30; RET +// writeGS writes to the GS base. +// +// This is written in assembly because it must be callable from assembly (ABI0) +// without an intermediate transition to ABIInternal. +// +// Preconditions: must be running in the lower address space, as it accesses +// global data. +TEXT ·writeGS(SB),NOSPLIT,$8-8 + MOVQ addr+0(FP), AX + + CMPB ·hasFSGSBASE(SB), $1 + JNE msr + + PUSHQ AX + CALL ·wrgsbase(SB) + POPQ AX + RET +msr: + PUSHQ AX + CALL ·wrgsmsr(SB) + POPQ AX + RET + // wrgsbase writes to the GS base. // // The code corresponds to: diff --git a/pkg/ring0/offsets_amd64.go b/pkg/ring0/offsets_amd64.go index 75f6218b3..38fe27c35 100644 --- a/pkg/ring0/offsets_amd64.go +++ b/pkg/ring0/offsets_amd64.go @@ -35,6 +35,9 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_HAS_XSAVE 0x%02x\n", reflect.ValueOf(&c.hasXSAVE).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_HAS_XSAVEOPT 0x%02x\n", reflect.ValueOf(&c.hasXSAVEOPT).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_FPU_STATE 0x%02x\n", reflect.ValueOf(&c.floatingPointState).Pointer()-reflect.ValueOf(c).Pointer()) e := &kernelEntry{} fmt.Fprintf(w, "\n// CPU entry offsets.\n") diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 9dac53c80..3f17fba49 100644 --- a/pkg/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -322,12 +322,3 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc func (p *PageTables) MarkReadOnlyShared() { p.readOnlyShared = true } - -// PrefaultRootTable touches the root table page to be sure that its physical -// pages are mapped. -// -//go:nosplit -//go:noinline -func (p *PageTables) PrefaultRootTable() PTE { - return p.root[0] -} diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD index db5787302..2a1602e2b 100644 --- a/pkg/safecopy/BUILD +++ b/pkg/safecopy/BUILD @@ -18,8 +18,9 @@ go_library( ], visibility = ["//:sandbox"], deps = [ - "//pkg/abi/linux", - "//pkg/syserror", + "//pkg/errors", + "//pkg/errors/linuxerr", + "//pkg/sighandling", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go index df63dd5f1..0dd0aea83 100644 --- a/pkg/safecopy/safecopy.go +++ b/pkg/safecopy/safecopy.go @@ -21,7 +21,9 @@ import ( "runtime" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/errors" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/sighandling" ) // SegvError is returned when a safecopy function receives SIGSEGV. @@ -82,7 +84,7 @@ var ( // when we get a SIGSEGV that is not interesting to us. savedSigSegVHandler uintptr - // same a above, but for SIGBUS signals. + // Same as above, but for SIGBUS signals. savedSigBusHandler uintptr ) @@ -131,18 +133,18 @@ func initializeAddresses() { func init() { initializeAddresses() - if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) } - if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) } - syserror.AddErrorUnwrapper(func(e error) (unix.Errno, bool) { + linuxerr.AddErrorUnwrapper(func(e error) (*errors.Error, bool) { switch e.(type) { case SegvError, BusError, AlignmentError: - return unix.EFAULT, true + return linuxerr.EFAULT, true default: - return 0, false + return nil, false } }) } diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index 2365b2c0d..15f84abea 100644 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go @@ -20,7 +20,6 @@ import ( "unsafe" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" ) // maxRegisterSize is the maximum register size used in memcpy and memclr. It @@ -332,39 +331,3 @@ func errorFromFaultSignal(addr uintptr, sig int32) error { panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr)) } } - -// ReplaceSignalHandler replaces the existing signal handler for the provided -// signal with the one that handles faults in safecopy-protected functions. -// -// It stores the value of the previously set handler in previous. -// -// This function will be called on initialization in order to install safecopy -// handlers for appropriate signals. These handlers will call the previous -// handler however, and if this is function is being used externally then the -// same courtesy is expected. -func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { - var sa linux.SigAction - const maskLen = 8 - - // Get the existing signal handler information, and save the current - // handler. Once we replace it, we will use this pointer to fall back to - // it when we receive other signals. - if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { - return e - } - - // Fail if there isn't a previous handler. - if sa.Handler == 0 { - return fmt.Errorf("previous handler for signal %x isn't set", sig) - } - - *previous = uintptr(sa.Handler) - - // Install our own handler. - sa.Handler = uint64(handler) - if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { - return e - } - - return nil -} diff --git a/pkg/safemem/io.go b/pkg/safemem/io.go index f039a5c34..9551ca853 100644 --- a/pkg/safemem/io.go +++ b/pkg/safemem/io.go @@ -207,58 +207,6 @@ func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) { return wbn, buf, rerr } -// FromIOReaderAt implements Reader for an io.ReaderAt. Does not repeatedly -// invoke io.ReaderAt.ReadAt because ReadAt is more strict than Read. A partial -// read indicates an error. This is not thread-safe. -type FromIOReaderAt struct { - ReaderAt io.ReaderAt - Offset int64 -} - -// ReadToBlocks implements Reader.ReadToBlocks. -func (r FromIOReaderAt) ReadToBlocks(dsts BlockSeq) (uint64, error) { - var buf []byte - var done uint64 - for !dsts.IsEmpty() { - dst := dsts.Head() - var n int - var err error - n, buf, err = r.readToBlock(dst, buf) - done += uint64(n) - if n != dst.Len() { - return done, err - } - dsts = dsts.Tail() - if err != nil { - if dsts.IsEmpty() && err == io.EOF { - return done, nil - } - return done, err - } - } - return done, nil -} - -func (r FromIOReaderAt) readToBlock(dst Block, buf []byte) (int, []byte, error) { - // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require - // safecopy. - if !dst.NeedSafecopy() { - n, err := r.ReaderAt.ReadAt(dst.ToSlice(), r.Offset) - r.Offset += int64(n) - return n, buf, err - } - if len(buf) < dst.Len() { - buf = make([]byte, dst.Len()) - } - rn, rerr := r.ReaderAt.ReadAt(buf[:dst.Len()], r.Offset) - r.Offset += int64(rn) - wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) - if wberr != nil { - return wbn, buf, wberr - } - return wbn, buf, rerr -} - // FromIOWriter implements Writer for an io.Writer by repeatedly invoking // io.Writer.Write until it returns an error or partial write. // diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD index 6cdd21b1b..1f371e513 100644 --- a/pkg/sentry/arch/fpu/BUILD +++ b/pkg/sentry/arch/fpu/BUILD @@ -9,6 +9,7 @@ go_library( "fpu_amd64.go", "fpu_amd64.s", "fpu_arm64.go", + "fpu_unsafe.go", ], visibility = ["//:sandbox"], deps = [ diff --git a/pkg/sentry/arch/fpu/fpu.go b/pkg/sentry/arch/fpu/fpu.go index 867d309a3..62bde19d3 100644 --- a/pkg/sentry/arch/fpu/fpu.go +++ b/pkg/sentry/arch/fpu/fpu.go @@ -17,7 +17,6 @@ package fpu import ( "fmt" - "reflect" ) // State represents floating point state. @@ -40,15 +39,3 @@ type ErrLoadingState struct { func (e ErrLoadingState) Error() string { return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supportedFeatures, e.savedFeatures) } - -// alignedBytes returns a slice of size bytes, aligned in memory to the given -// alignment. This is used because we require certain structures to be aligned -// in a specific way (for example, the X86 floating point data). -func alignedBytes(size, alignment uint) []byte { - data := make([]byte, size+alignment-1) - offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment)) - if offset == 0 { - return data[:size:size] - } - return data[alignment-offset:][:size:size] -} diff --git a/pkg/sentry/arch/fpu/fpu_unsafe.go b/pkg/sentry/arch/fpu/fpu_unsafe.go new file mode 100644 index 000000000..c91dc99be --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_unsafe.go @@ -0,0 +1,31 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fpu + +import ( + "unsafe" +) + +// alignedBytes returns a slice of size bytes, aligned in memory to the given +// alignment. This is used because we require certain structures to be aligned +// in a specific way (for example, the X86 floating point data). +func alignedBytes(size, alignment uint) []byte { + data := make([]byte, size+alignment-1) + offset := uint(uintptr(unsafe.Pointer(&data[0])) % uintptr(alignment)) + if offset == 0 { + return data[:size:size] + } + return data[alignment-offset:][:size:size] +} diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 7ee237c9f..cfb33a398 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,17 +1,25 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") package(licenses = ["notice"]) +proto_library( + name = "control", + srcs = ["control.proto"], + visibility = ["//visibility:public"], +) + go_library( name = "control", srcs = [ "control.go", + "events.go", "fs.go", "lifecycle.go", "logging.go", "pprof.go", "proc.go", "state.go", + "usage.go", ], visibility = [ "//:sandbox", @@ -19,6 +27,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/eventchannel", "//pkg/fd", "//pkg/log", "//pkg/sentry/fdimport", @@ -39,6 +48,7 @@ go_library( "//pkg/tcpip/link/sniffer", "//pkg/urpc", "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/control/control.proto b/pkg/sentry/control/control.proto new file mode 100644 index 000000000..72dda3fbc --- /dev/null +++ b/pkg/sentry/control/control.proto @@ -0,0 +1,40 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package gvisor; + +// ControlConfig configures the permission of controls. +message ControlConfig { + // Names for individual control URPC service objects. + // Any new service object that should be given conditional access should be + // named here and conditionally added based on presence in allowed_controls. + enum Endpoint { + UNKNOWN = 0; + EVENTS = 1; + FS = 2; + LIFECYCLE = 3; + LOGGING = 4; + PROFILE = 5; + USAGE = 6; + PROC = 7; + STATE = 8; + DEBUG = 9; + } + + // allowed_controls represents which endpoints may be registered to the + // server. + repeated Endpoint allowed_controls = 1; +} diff --git a/pkg/sentry/control/events.go b/pkg/sentry/control/events.go new file mode 100644 index 000000000..92e437ae7 --- /dev/null +++ b/pkg/sentry/control/events.go @@ -0,0 +1,65 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package control + +import ( + "errors" + "fmt" + + "gvisor.dev/gvisor/pkg/eventchannel" + "gvisor.dev/gvisor/pkg/urpc" +) + +// EventsOpts are the arguments for eventchannel-related commands. +type EventsOpts struct { + urpc.FilePayload +} + +// Events is the control server state for eventchannel-related commands. +type Events struct { + emitter eventchannel.Emitter +} + +// AttachDebugEmitter receives a connected unix domain socket FD from the client +// and establishes it as a new emitter for the sentry eventchannel. Any existing +// emitters are replaced on a subsequent attach. +func (e *Events) AttachDebugEmitter(o *EventsOpts, _ *struct{}) error { + if len(o.FilePayload.Files) < 1 { + return errors.New("no output writer provided") + } + + sock, err := o.ReleaseFD(0) + if err != nil { + return err + } + sockFD := sock.Release() + + // SocketEmitter takes ownership of sockFD. + emitter, err := eventchannel.SocketEmitter(sockFD) + if err != nil { + return fmt.Errorf("failed to create SocketEmitter for FD %d: %v", sockFD, err) + } + + // If there is already a debug emitter, close the old one. + if e.emitter != nil { + e.emitter.Close() + } + + e.emitter = eventchannel.DebugEmitterFrom(emitter) + + // Register the new stream destination. + eventchannel.AddEmitter(e.emitter) + return nil +} diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 2f3664c57..f721b7236 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -26,6 +26,23 @@ import ( "gvisor.dev/gvisor/pkg/urpc" ) +const ( + // DefaultBlockProfileRate is the default profiling rate for block + // profiles. + // + // The default here is 10%, which will record a stacktrace 10% of the + // time when blocking occurs. Since these events should not be super + // frequent, we expect this to achieve a reasonable balance between + // collecting the data we need and imposing a high performance cost + // (e.g. skewing even the CPU profile). + DefaultBlockProfileRate = 10 + + // DefaultMutexProfileRate is the default profiling rate for mutex + // profiles. Like the block rate above, we use a default rate of 10% + // for the same reasons. + DefaultMutexProfileRate = 10 +) + // Profile includes profile-related RPC stubs. It provides a way to // control the built-in runtime profiling facilities. // @@ -175,12 +192,8 @@ func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error { defer p.blockMu.Unlock() // Always set the rate. We then wait to collect a profile at this rate, - // and disable when we're done. Note that the default here is 10%, which - // will record a stacktrace 10% of the time when blocking occurs. Since - // these events should not be super frequent, we expect this to achieve - // a reasonable balance between collecting the data we need and imposing - // a high performance cost (e.g. skewing even the CPU profile). - rate := 10 + // and disable when we're done. + rate := DefaultBlockProfileRate if o.Rate != 0 { rate = o.Rate } @@ -220,9 +233,8 @@ func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error { p.mutexMu.Lock() defer p.mutexMu.Unlock() - // Always set the fraction. Like the block rate above, we use - // a default rate of 10% for the same reasons. - fraction := 10 + // Always set the fraction. + fraction := DefaultMutexProfileRate if o.Fraction != 0 { fraction = o.Fraction } diff --git a/pkg/sentry/control/usage.go b/pkg/sentry/control/usage.go new file mode 100644 index 000000000..cc78d3f45 --- /dev/null +++ b/pkg/sentry/control/usage.go @@ -0,0 +1,183 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package control + +import ( + "fmt" + "os" + "runtime" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/urpc" +) + +// Usage includes usage-related RPC stubs. +type Usage struct { + Kernel *kernel.Kernel +} + +// MemoryUsageOpts contains usage options. +type MemoryUsageOpts struct { + // Full indicates that a full accounting should be done. If Full is not + // specified, then a partial accounting will be done, and Unknown will + // contain a majority of memory. See Collect for more information. + Full bool `json:"Full"` +} + +// MemoryUsage is a memory usage structure. +type MemoryUsage struct { + Unknown uint64 `json:"Unknown"` + System uint64 `json:"System"` + Anonymous uint64 `json:"Anonymous"` + PageCache uint64 `json:"PageCache"` + Mapped uint64 `json:"Mapped"` + Tmpfs uint64 `json:"Tmpfs"` + Ramdiskfs uint64 `json:"Ramdiskfs"` + Total uint64 `json:"Total"` +} + +// MemoryUsageFileOpts contains usage file options. +type MemoryUsageFileOpts struct { + // Version is used to ensure both sides agree on the format of the + // shared memory buffer. + Version uint64 `json:"Version"` +} + +// MemoryUsageFile contains the file handle to the usage file. +type MemoryUsageFile struct { + urpc.FilePayload +} + +// UsageFD returns the file that tracks the memory usage of the application. +func (u *Usage) UsageFD(opts *MemoryUsageFileOpts, out *MemoryUsageFile) error { + // Only support version 1 for now. + if opts.Version != 1 { + return fmt.Errorf("unsupported version requested: %d", opts.Version) + } + + mf := u.Kernel.MemoryFile() + *out = MemoryUsageFile{ + FilePayload: urpc.FilePayload{ + Files: []*os.File{ + usage.MemoryAccounting.File, + mf.File(), + }, + }, + } + + return nil +} + +// Collect returns memory used by the sandboxed application. +func (u *Usage) Collect(opts *MemoryUsageOpts, out *MemoryUsage) error { + if opts.Full { + // Ensure everything is up to date. + if err := u.Kernel.MemoryFile().UpdateUsage(); err != nil { + return err + } + + // Copy out a snapshot. + snapshot, total := usage.MemoryAccounting.Copy() + *out = MemoryUsage{ + System: snapshot.System, + Anonymous: snapshot.Anonymous, + PageCache: snapshot.PageCache, + Mapped: snapshot.Mapped, + Tmpfs: snapshot.Tmpfs, + Ramdiskfs: snapshot.Ramdiskfs, + Total: total, + } + } else { + // Get total usage from the MemoryFile implementation. + total, err := u.Kernel.MemoryFile().TotalUsage() + if err != nil { + return err + } + + // The memory accounting is guaranteed to be accurate only when + // UpdateUsage is called. If UpdateUsage is not called, then only Mapped + // will be up-to-date. + snapshot, _ := usage.MemoryAccounting.Copy() + *out = MemoryUsage{ + Unknown: total, + Mapped: snapshot.Mapped, + Total: total + snapshot.Mapped, + } + + } + + return nil +} + +// UsageReduceOpts contains options to Usage.Reduce(). +type UsageReduceOpts struct { + // If Wait is true, Reduce blocks until all activity initiated by + // Usage.Reduce() has completed. + Wait bool `json:"wait"` +} + +// UsageReduceOutput contains output from Usage.Reduce(). +type UsageReduceOutput struct{} + +// Reduce requests that the sentry attempt to reduce its memory usage. +func (u *Usage) Reduce(opts *UsageReduceOpts, out *UsageReduceOutput) error { + mf := u.Kernel.MemoryFile() + mf.StartEvictions() + if opts.Wait { + mf.WaitForEvictions() + } + return nil +} + +// MemoryUsageRecord contains the mapping and platform memory file. +type MemoryUsageRecord struct { + mmap uintptr + stats *usage.RTMemoryStats + mf os.File +} + +// NewMemoryUsageRecord creates a new MemoryUsageRecord from usageFile and +// platformFile. +func NewMemoryUsageRecord(usageFile, platformFile os.File) (*MemoryUsageRecord, error) { + mmap, _, e := unix.RawSyscall6(unix.SYS_MMAP, 0, usage.RTMemoryStatsSize, unix.PROT_READ, unix.MAP_SHARED, usageFile.Fd(), 0) + if e != 0 { + return nil, fmt.Errorf("mmap returned %d, want 0", e) + } + + m := MemoryUsageRecord{ + mmap: mmap, + stats: usage.RTMemoryStatsPointer(mmap), + mf: platformFile, + } + + runtime.SetFinalizer(&m, finalizer) + return &m, nil +} + +func finalizer(m *MemoryUsageRecord) { + unix.RawSyscall(unix.SYS_MUNMAP, m.mmap, usage.RTMemoryStatsSize, 0) +} + +// Fetch fetches the usage info from a MemoryUsageRecord. +func (m *MemoryUsageRecord) Fetch() (mapped, unknown, total uint64, err error) { + var stat unix.Stat_t + if err := unix.Fstat(int(m.mf.Fd()), &stat); err != nil { + return 0, 0, 0, err + } + fmem := uint64(stat.Blocks) * 512 + return m.stats.RTMapped, fmem, m.stats.RTMapped + fmem, nil +} diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD index 4c8604d58..66b9ed523 100644 --- a/pkg/sentry/devices/memdev/BUILD +++ b/pkg/sentry/devices/memdev/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/rand", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", @@ -23,7 +24,6 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go index fece3e762..fc702c9f6 100644 --- a/pkg/sentry/devices/memdev/full.go +++ b/pkg/sentry/devices/memdev/full.go @@ -16,8 +16,8 @@ package memdev import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -66,12 +66,12 @@ func (fd *fullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.Rea // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *fullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.ENOSPC + return 0, linuxerr.ENOSPC } // Write implements vfs.FileDescriptionImpl.Write. func (fd *fullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.ENOSPC + return 0, linuxerr.ENOSPC } // Seek implements vfs.FileDescriptionImpl.Seek. diff --git a/pkg/sentry/devices/quotedev/BUILD b/pkg/sentry/devices/quotedev/BUILD deleted file mode 100644 index d09214e3e..000000000 --- a/pkg/sentry/devices/quotedev/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "quotedev", - srcs = ["quotedev.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/fsimpl/devtmpfs", - "//pkg/sentry/vfs", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/devices/quotedev/quotedev.go b/pkg/sentry/devices/quotedev/quotedev.go deleted file mode 100644 index 6114cb724..000000000 --- a/pkg/sentry/devices/quotedev/quotedev.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package quotedev implements a vfs.Device for /dev/gvisor_quote. -package quotedev - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -const ( - quoteDevMinor = 0 -) - -// quoteDevice implements vfs.Device for /dev/gvisor_quote -// -// +stateify savable -type quoteDevice struct{} - -// Open implements vfs.Device.Open. -// TODO(b/157161182): Add support for attestation ioctls. -func (quoteDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - return nil, syserror.EIO -} - -// Register registers all devices implemented by this package in vfsObj. -func Register(vfsObj *vfs.VirtualFilesystem) error { - return vfsObj.RegisterDevice(vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, quoteDevice{}, &vfs.RegisterDeviceOptions{ - GroupName: "gvisor_quote", - }) -} - -// CreateDevtmpfsFiles creates device special files in dev representing all -// devices implemented by this package. -func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { - return dev.CreateDeviceFile(ctx, "gvisor_quote", vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, 0666 /* mode */) -} diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD index b4b6ca38a..ab4cd0b33 100644 --- a/pkg/sentry/devices/ttydev/BUILD +++ b/pkg/sentry/devices/ttydev/BUILD @@ -9,8 +9,8 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/vfs", - "//pkg/syserror", ], ) diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go index a287c65ca..29b79b5d6 100644 --- a/pkg/sentry/devices/ttydev/ttydev.go +++ b/pkg/sentry/devices/ttydev/ttydev.go @@ -18,9 +18,9 @@ package ttydev import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -36,7 +36,7 @@ type ttyDevice struct{} // Open implements vfs.Device.Open. func (ttyDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - return nil, syserror.EIO + return nil, linuxerr.EIO } // Register registers all devices implemented by this package in vfsObj. diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 58fe1e77c..4e573d249 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -68,7 +68,6 @@ go_library( "//pkg/sentry/usage", "//pkg/state", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index a8591052c..e48bd4dba 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -195,7 +194,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { attrs, err := next.Inode.overlay.lower.UnstableAttr(ctx) if err != nil { log.Warningf("copy up failed to get lower attributes: %v", err) - return syserror.EIO + return linuxerr.EIO } var childUpperInode *Inode @@ -211,7 +210,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { childFile, err := parentUpper.Create(ctx, root, next.name, FileFlags{Read: true, Write: true}, attrs.Perms) if err != nil { log.Warningf("copy up failed to create file: %v", err) - return syserror.EIO + return linuxerr.EIO } defer childFile.DecRef(ctx) childUpperInode = childFile.Dirent.Inode @@ -219,13 +218,13 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { case Directory: if err := parentUpper.CreateDirectory(ctx, root, next.name, attrs.Perms); err != nil { log.Warningf("copy up failed to create directory: %v", err) - return syserror.EIO + return linuxerr.EIO } childUpper, err := parentUpper.Lookup(ctx, next.name) if err != nil { werr := fmt.Errorf("copy up failed to lookup directory: %v", err) cleanupUpper(ctx, parentUpper, next.name, werr) - return syserror.EIO + return linuxerr.EIO } defer childUpper.DecRef(ctx) childUpperInode = childUpper.Inode @@ -235,17 +234,17 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { link, err := childLower.Readlink(ctx) if err != nil { log.Warningf("copy up failed to read symlink value: %v", err) - return syserror.EIO + return linuxerr.EIO } if err := parentUpper.CreateLink(ctx, root, link, next.name); err != nil { log.Warningf("copy up failed to create symlink: %v", err) - return syserror.EIO + return linuxerr.EIO } childUpper, err := parentUpper.Lookup(ctx, next.name) if err != nil { werr := fmt.Errorf("copy up failed to lookup symlink: %v", err) cleanupUpper(ctx, parentUpper, next.name, werr) - return syserror.EIO + return linuxerr.EIO } defer childUpper.DecRef(ctx) childUpperInode = childUpper.Inode @@ -259,14 +258,14 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { if err := copyAttributesLocked(ctx, childUpperInode, next.Inode.overlay.lower); err != nil { werr := fmt.Errorf("copy up failed to copy up attributes: %v", err) cleanupUpper(ctx, parentUpper, next.name, werr) - return syserror.EIO + return linuxerr.EIO } // Copy the entire file. if err := copyContentsLocked(ctx, childUpperInode, next.Inode.overlay.lower, attrs.Size); err != nil { werr := fmt.Errorf("copy up failed to copy up contents: %v", err) cleanupUpper(ctx, parentUpper, next.name, werr) - return syserror.EIO + return linuxerr.EIO } lowerMappable := next.Inode.overlay.lower.Mappable() @@ -274,7 +273,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { if lowerMappable != nil && upperMappable == nil { werr := fmt.Errorf("copy up failed: cannot ensure memory mapping coherence") cleanupUpper(ctx, parentUpper, next.name, werr) - return syserror.EIO + return linuxerr.EIO } // Propagate memory mappings to the upper Inode. diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index e28a8961b..7baf26b24 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -34,7 +34,6 @@ go_library( "//pkg/sentry/mm", "//pkg/sentry/pgalloc", "//pkg/sentry/socket/netstack", - "//pkg/syserror", "//pkg/tcpip/link/tun", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/fs/dev/full.go b/pkg/sentry/fs/dev/full.go index deb9c6ad8..6f0c1fc68 100644 --- a/pkg/sentry/fs/dev/full.go +++ b/pkg/sentry/fs/dev/full.go @@ -17,9 +17,9 @@ package dev import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -77,5 +77,5 @@ var _ fs.FileOperations = (*fullFileOperations)(nil) // Write implements FileOperations.Write. func (*fullFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, syserror.ENOSPC + return 0, linuxerr.ENOSPC } diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index ad8ff227e..d300a32e0 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) type globalDirentMap struct { @@ -963,7 +962,7 @@ func (d *Dirent) isMountPointLocked() bool { func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err error) { // Did we race with deletion? if atomic.LoadInt32(&d.deleted) != 0 { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } // Refuse to mount a symlink. @@ -998,7 +997,7 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error { // Did we race with deletion? if atomic.LoadInt32(&d.deleted) != 0 { - return syserror.ENOENT + return linuxerr.ENOENT } // Remount our former child in its place. diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index 5c889c861..9f1fe5160 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -22,7 +22,6 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", @@ -46,7 +45,6 @@ go_test( "//pkg/hostarch", "//pkg/sentry/contexttest", "//pkg/sentry/fs", - "//pkg/syserror", "//pkg/usermem", "@com_github_google_uuid//:go_default_library", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index f8a29816b..4370cce33 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -142,7 +141,7 @@ func (p *pipeOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{secio.FullReader{p.file}}) total := int64(bufN) + n if err != nil && isBlockError(err) { - return total, syserror.ErrWouldBlock + return total, linuxerr.ErrWouldBlock } return total, err } @@ -151,13 +150,13 @@ func (p *pipeOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO func (p *pipeOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { n, err := src.CopyInTo(ctx, safemem.FromIOWriter{p.file}) if err != nil && isBlockError(err) { - return n, syserror.ErrWouldBlock + return n, linuxerr.ErrWouldBlock } return n, err } // isBlockError unwraps os errors and checks if they are caused by EAGAIN or -// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock. +// EWOULDBLOCK. This is so they can be transformed into linuxerr.ErrWouldBlock. func isBlockError(err error) bool { if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) { return true diff --git a/pkg/sentry/fs/fdpipe/pipe_opener.go b/pkg/sentry/fs/fdpipe/pipe_opener.go index adda19168..e91e1b5cb 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener.go @@ -21,9 +21,9 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/syserror" ) // NonBlockingOpener is a generic host file opener used to retry opening host @@ -40,7 +40,7 @@ func Open(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags) (fs p := &pipeOpenState{} canceled := false for { - if file, err := p.TryOpen(ctx, opener, flags); err != syserror.ErrWouldBlock { + if file, err := p.TryOpen(ctx, opener, flags); err != linuxerr.ErrWouldBlock { return file, err } @@ -51,7 +51,7 @@ func Open(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags) (fs if p.hostFile != nil { p.hostFile.Close() } - return nil, syserror.ErrInterrupted + return nil, linuxerr.ErrInterrupted } cancel := ctx.SleepStart() @@ -106,13 +106,13 @@ func (p *pipeOpenState) TryOpen(ctx context.Context, opener NonBlockingOpener, f } return newPipeOperations(ctx, opener, flags, f, nil) - // Handle opening O_WRONLY blocking: convert ENXIO to syserror.ErrWouldBlock. + // Handle opening O_WRONLY blocking: convert ENXIO to linuxerr.ErrWouldBlock. // See TryOpenWriteOnly for more details. case flags.Write: return p.TryOpenWriteOnly(ctx, opener) default: - // Handle opening O_RDONLY blocking: convert EOF from read to syserror.ErrWouldBlock. + // Handle opening O_RDONLY blocking: convert EOF from read to linuxerr.ErrWouldBlock. // See TryOpenReadOnly for more details. return p.TryOpenReadOnly(ctx, opener) } @@ -120,7 +120,7 @@ func (p *pipeOpenState) TryOpen(ctx context.Context, opener NonBlockingOpener, f // TryOpenReadOnly tries to open a host pipe read only but only returns a fs.File when // there is a coordinating writer. Call TryOpenReadOnly repeatedly on the same pipeOpenState -// until syserror.ErrWouldBlock is no longer returned. +// until linuxerr.ErrWouldBlock is no longer returned. // // How it works: // @@ -150,7 +150,7 @@ func (p *pipeOpenState) TryOpenReadOnly(ctx context.Context, opener NonBlockingO if n == 0 { // EOF means that we're not ready yet. if rerr == nil || rerr == io.EOF { - return nil, syserror.ErrWouldBlock + return nil, linuxerr.ErrWouldBlock } // Any error that is not EWOULDBLOCK also means we're not // ready yet, and probably never will be ready. In this @@ -175,16 +175,16 @@ func (p *pipeOpenState) TryOpenReadOnly(ctx context.Context, opener NonBlockingO // TryOpenWriteOnly tries to open a host pipe write only but only returns a fs.File when // there is a coordinating reader. Call TryOpenWriteOnly repeatedly on the same pipeOpenState -// until syserror.ErrWouldBlock is no longer returned. +// until linuxerr.ErrWouldBlock is no longer returned. // // How it works: // // Opening a pipe write only will return ENXIO until readers are available. Converts the ENXIO -// to an syserror.ErrWouldBlock, to tell callers to retry. +// to an linuxerr.ErrWouldBlock, to tell callers to retry. func (*pipeOpenState) TryOpenWriteOnly(ctx context.Context, opener NonBlockingOpener) (*pipeOperations, error) { hostFile, err := opener.NonBlockingOpen(ctx, fs.PermMask{Write: true}) if unwrapError(err) == unix.ENXIO { - return nil, syserror.ErrWouldBlock + return nil, linuxerr.ErrWouldBlock } if err != nil { return nil, err diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go index 89d8be741..e1587288e 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -146,18 +145,18 @@ func TestTryOpen(t *testing.T) { err: unix.ENOENT, }, { - desc: "Blocking Write only returns with syserror.ErrWouldBlock", + desc: "Blocking Write only returns with linuxerr.ErrWouldBlock", makePipe: true, flags: fs.FileFlags{Write: true}, expectFile: false, - err: syserror.ErrWouldBlock, + err: linuxerr.ErrWouldBlock, }, { - desc: "Blocking Read only returns with syserror.ErrWouldBlock", + desc: "Blocking Read only returns with linuxerr.ErrWouldBlock", makePipe: true, flags: fs.FileFlags{Read: true}, expectFile: false, - err: syserror.ErrWouldBlock, + err: linuxerr.ErrWouldBlock, }, } { name := pipename() @@ -316,7 +315,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) { // another writer comes along. This means we can open the same pipe write only // with no problems + write to it, given that opener.Open already tried to open // the pipe RDONLY and succeeded, which we know happened if TryOpen returns - // syserror.ErrwouldBlock. + // linuxerr.ErrwouldBlock. // // This simulates the open(RDONLY) <-> open(WRONLY)+write race we care about, but // does not cause our test to be racy (which would be terrible). @@ -328,8 +327,8 @@ func TestCopiedReadAheadBuffer(t *testing.T) { pipeOps.Release(ctx) t.Fatalf("open(%s, %o) got file, want nil", name, unix.O_RDONLY) } - if err != syserror.ErrWouldBlock { - t.Fatalf("open(%s, %o) got error %v, want %v", name, unix.O_RDONLY, err, syserror.ErrWouldBlock) + if err != linuxerr.ErrWouldBlock { + t.Fatalf("open(%s, %o) got error %v, want %v", name, unix.O_RDONLY, err, linuxerr.ErrWouldBlock) } // Then open the same pipe write only and write some bytes to it. The next diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index 4c8905a7e..63900e766 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -238,7 +237,7 @@ func TestPipeRequest(t *testing.T) { context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))}, flags: fs.FileFlags{Read: true}, keepOpenPartner: true, - err: syserror.ErrWouldBlock, + err: linuxerr.ErrWouldBlock, }, { desc: "Writev on pipe from empty buffer returns nil", @@ -410,8 +409,8 @@ func TestPipeReadsAccumulate(t *testing.T) { n, err := p.Read(ctx, file, iov, 0) total := n iov = iov.DropFirst64(n) - if err != syserror.ErrWouldBlock { - t.Fatalf("Readv got error %v, want %v", err, syserror.ErrWouldBlock) + if err != linuxerr.ErrWouldBlock { + t.Fatalf("Readv got error %v, want %v", err, linuxerr.ErrWouldBlock) } // Write a few more bytes to allow us to read more/accumulate. @@ -479,8 +478,8 @@ func TestPipeWritesAccumulate(t *testing.T) { } iov := usermem.BytesIOSequence(writeBuffer) n, err := p.Write(ctx, file, iov, 0) - if err != syserror.ErrWouldBlock { - t.Fatalf("Writev got error %v, want %v", err, syserror.ErrWouldBlock) + if err != linuxerr.ErrWouldBlock { + t.Fatalf("Writev got error %v, want %v", err, linuxerr.ErrWouldBlock) } if n != int64(pipeSize) { t.Fatalf("Writev partial write, got: %v, want %v", n, pipeSize) diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 57f904801..df04f044d 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsmetric" @@ -27,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -195,10 +195,10 @@ func (f *File) EventUnregister(e *waiter.Entry) { // offset to the value returned by f.FileOperations.Seek if the operation // is successful. // -// Returns syserror.ErrInterrupted if seeking was interrupted. +// Returns linuxerr.ErrInterrupted if seeking was interrupted. func (f *File) Seek(ctx context.Context, whence SeekWhence, offset int64) (int64, error) { if !f.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } defer f.mu.Unlock() @@ -217,10 +217,10 @@ func (f *File) Seek(ctx context.Context, whence SeekWhence, offset int64) (int64 // Readdir unconditionally updates the access time on the File's Inode, // see fs/readdir.c:iterate_dir. // -// Returns syserror.ErrInterrupted if reading was interrupted. +// Returns linuxerr.ErrInterrupted if reading was interrupted. func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error { if !f.mu.Lock(ctx) { - return syserror.ErrInterrupted + return linuxerr.ErrInterrupted } defer f.mu.Unlock() @@ -232,13 +232,13 @@ func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error { // Readv calls f.FileOperations.Read with f as the File, advancing the file // offset if f.FileOperations.Read returns bytes read > 0. // -// Returns syserror.ErrInterrupted if reading was interrupted. +// Returns linuxerr.ErrInterrupted if reading was interrupted. func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) { start := fsmetric.StartReadWait() defer fsmetric.FinishReadWait(fsmetric.ReadWait, start) if !f.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } fsmetric.Reads.Increment() @@ -260,7 +260,7 @@ func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64) defer fsmetric.FinishReadWait(fsmetric.ReadWait, start) if !f.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } fsmetric.Reads.Increment() @@ -276,10 +276,10 @@ func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64) // unavoidably racy for network file systems. Writev also truncates src // to avoid overrunning the current file size limit if necessary. // -// Returns syserror.ErrInterrupted if writing was interrupted. +// Returns linuxerr.ErrInterrupted if writing was interrupted. func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error) { if !f.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append) // Handle append mode. @@ -297,7 +297,7 @@ func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error case ok && limit == 0: unlockAppendMu() f.mu.Unlock() - return 0, syserror.ErrExceedsFileSizeLimit + return 0, linuxerr.ErrExceedsFileSizeLimit case ok: src = src.TakeFirst64(limit) } @@ -335,7 +335,7 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64 limit, ok := f.checkLimit(ctx, offset) switch { case ok && limit == 0: - return 0, syserror.ErrExceedsFileSizeLimit + return 0, linuxerr.ErrExceedsFileSizeLimit case ok: src = src.TakeFirst64(limit) } @@ -352,7 +352,7 @@ func (f *File) offsetForAppend(ctx context.Context, offset *int64) error { if err != nil { // This is an odd error, we treat it as evidence that // something is terribly wrong with the filesystem. - return syserror.EIO + return linuxerr.EIO } // Update the offset. @@ -381,10 +381,10 @@ func (f *File) checkLimit(ctx context.Context, offset int64) (int64, bool) { // Fsync calls f.FileOperations.Fsync with f as the File. // -// Returns syserror.ErrInterrupted if syncing was interrupted. +// Returns linuxerr.ErrInterrupted if syncing was interrupted. func (f *File) Fsync(ctx context.Context, start int64, end int64, syncType SyncType) error { if !f.mu.Lock(ctx) { - return syserror.ErrInterrupted + return linuxerr.ErrInterrupted } defer f.mu.Unlock() @@ -393,10 +393,10 @@ func (f *File) Fsync(ctx context.Context, start int64, end int64, syncType SyncT // Flush calls f.FileOperations.Flush with f as the File. // -// Returns syserror.ErrInterrupted if syncing was interrupted. +// Returns linuxerr.ErrInterrupted if syncing was interrupted. func (f *File) Flush(ctx context.Context) error { if !f.mu.Lock(ctx) { - return syserror.ErrInterrupted + return linuxerr.ErrInterrupted } defer f.mu.Unlock() @@ -405,10 +405,10 @@ func (f *File) Flush(ctx context.Context) error { // ConfigureMMap calls f.FileOperations.ConfigureMMap with f as the File. // -// Returns syserror.ErrInterrupted if interrupted. +// Returns linuxerr.ErrInterrupted if interrupted. func (f *File) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { if !f.mu.Lock(ctx) { - return syserror.ErrInterrupted + return linuxerr.ErrInterrupted } defer f.mu.Unlock() @@ -417,10 +417,10 @@ func (f *File) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { // UnstableAttr calls f.FileOperations.UnstableAttr with f as the File. // -// Returns syserror.ErrInterrupted if interrupted. +// Returns linuxerr.ErrInterrupted if interrupted. func (f *File) UnstableAttr(ctx context.Context) (UnstableAttr, error) { if !f.mu.Lock(ctx) { - return UnstableAttr{}, syserror.ErrInterrupted + return UnstableAttr{}, linuxerr.ErrInterrupted } defer f.mu.Unlock() @@ -495,7 +495,7 @@ type lockedReader struct { // Read implements io.Reader.Read. func (r *lockedReader) Read(buf []byte) (int, error) { if r.Ctx.Interrupted() { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset) r.Offset += n @@ -505,7 +505,7 @@ func (r *lockedReader) Read(buf []byte) (int, error) { // ReadAt implements io.Reader.ReadAt. func (r *lockedReader) ReadAt(buf []byte, offset int64) (int, error) { if r.Ctx.Interrupted() { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), offset) return int(n), err @@ -530,7 +530,7 @@ type lockedWriter struct { // Write implements io.Writer.Write. func (w *lockedWriter) Write(buf []byte) (int, error) { if w.Ctx.Interrupted() { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } n, err := w.WriteAt(buf, w.Offset) w.Offset += int64(n) @@ -549,7 +549,7 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { // contract. Enforce that here. for written < len(buf) { if w.Ctx.Interrupted() { - return written, syserror.ErrInterrupted + return written, linuxerr.ErrInterrupted } var n int64 n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written)) diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index 6ec721022..ce47c3907 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -120,7 +120,7 @@ type FileOperations interface { // Files with !FileFlags.Pwrite. // // If only part of src could be written, Write must return an error - // indicating why (e.g. syserror.ErrWouldBlock). + // indicating why (e.g. linuxerr.ErrWouldBlock). // // Write does not check permissions nor flags. // diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 06c07c807..a27dd0b9a 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -16,6 +16,7 @@ package fs import ( "io" + "math" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" @@ -23,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -246,7 +246,7 @@ func (f *overlayFileOperations) onTop(ctx context.Context, file *File, fn func(* // Something very wrong; return a generic filesystem // error to avoid propagating internals. f.upperMu.Unlock() - return syserror.EIO + return linuxerr.EIO } // Save upper file. @@ -361,10 +361,13 @@ func (*overlayFileOperations) ConfigureMMap(ctx context.Context, file *File, opt return linuxerr.ENODEV } - // FIXME(jamieliu): This is a copy/paste of fsutil.GenericConfigureMMap, - // which we can't use because the overlay implementation is in package fs, - // so depending on fs/fsutil would create a circular dependency. Move - // overlay to fs/overlay. + // TODO(gvisor.dev/issue/1624): This is a copy/paste of + // fsutil.GenericConfigureMMap, which we can't use because the overlay + // implementation is in package fs, so depending on fs/fsutil would create + // a circular dependency. VFS2 overlay doesn't have this issue. + if opts.Offset+opts.Length > math.MaxInt64 { + return linuxerr.EOVERFLOW + } opts.Mappable = o opts.MappingIdentity = file file.IncRef() diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 6bf2d51cb..1a59800ea 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -90,7 +90,6 @@ go_library( "//pkg/sentry/usage", "//pkg/state", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 00b3bb29b..38e3ed42d 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -16,13 +16,13 @@ package fsutil import ( "io" + "math" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -211,6 +211,9 @@ func (FileNoMMap) ConfigureMMap(context.Context, *fs.File, *memmap.MMapOpts) err // GenericConfigureMMap implements fs.FileOperations.ConfigureMMap for most // filesystems that support memory mapping. func GenericConfigureMMap(file *fs.File, m memmap.Mappable, opts *memmap.MMapOpts) error { + if opts.Offset+opts.Length > math.MaxInt64 { + return linuxerr.EOVERFLOW + } opts.Mappable = m opts.MappingIdentity = file file.IncRef() @@ -232,12 +235,12 @@ type FileNoSplice struct{} // WriteTo implements fs.FileOperations.WriteTo. func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) { - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS } // ReadFrom implements fs.FileOperations.ReadFrom. func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) { - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS } // DirFileOperations implements most of fs.FileOperations for directories, @@ -255,12 +258,12 @@ type DirFileOperations struct { // Read implements fs.FileOperations.Read func (*DirFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // Write implements fs.FileOperations.Write. func (*DirFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // StaticDirFileOperations implements fs.FileOperations for directories with diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 23528bf25..37ddb1a3c 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -93,7 +93,8 @@ func NewHostFileMapper() *HostFileMapper { func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { f.refsMu.Lock() defer f.refsMu.Unlock() - for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize { + chunkStart := mr.Start &^ chunkMask + for { refs := f.refs[chunkStart] pgs := pagesInChunk(mr, chunkStart) if refs+pgs < refs { @@ -101,6 +102,10 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { panic(fmt.Sprintf("HostFileMapper.IncRefOn(%v): adding %d page references to chunk %#x, which has %d page references", mr, pgs, chunkStart, refs)) } f.refs[chunkStart] = refs + pgs + chunkStart += chunkSize + if chunkStart >= mr.End || chunkStart == 0 { + break + } } } @@ -112,7 +117,8 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { f.refsMu.Lock() defer f.refsMu.Unlock() - for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize { + chunkStart := mr.Start &^ chunkMask + for { refs := f.refs[chunkStart] pgs := pagesInChunk(mr, chunkStart) switch { @@ -128,6 +134,10 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { case refs < pgs: panic(fmt.Sprintf("HostFileMapper.DecRefOn(%v): removing %d page references from chunk %#x, which has %d page references", mr, pgs, chunkStart, refs)) } + chunkStart += chunkSize + if chunkStart >= mr.End || chunkStart == 0 { + break + } } } @@ -161,7 +171,8 @@ func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, if write { prot |= unix.PROT_WRITE } - for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize { + chunkStart := fr.Start &^ chunkMask + for { m, ok := f.mappings[chunkStart] if !ok { addr, _, errno := unix.Syscall6( @@ -201,6 +212,10 @@ func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, endOff = fr.End - chunkStart } fn(f.unsafeBlockFromChunkMapping(m.addr).TakeFirst64(endOff).DropFirst64(startOff)) + chunkStart += chunkSize + if chunkStart >= fr.End || chunkStart == 0 { + break + } } return nil } diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index 7c2de04c1..06a994193 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -167,7 +166,7 @@ func (i *InodeSimpleAttributes) DropLink() { // StatFS implements fs.InodeOperations.StatFS. func (i *InodeSimpleAttributes) StatFS(context.Context) (fs.Info, error) { if i.fsType == 0 { - return fs.Info{}, syserror.ENOSYS + return fs.Info{}, linuxerr.ENOSYS } return fs.Info{Type: i.fsType}, nil } @@ -294,7 +293,7 @@ type InodeNoStatFS struct{} // StatFS implements fs.InodeOperations.StatFS. func (InodeNoStatFS) StatFS(context.Context) (fs.Info, error) { - return fs.Info{}, syserror.ENOSYS + return fs.Info{}, linuxerr.ENOSYS } // InodeStaticFileGetter implements GetFile for a file with static contents. @@ -401,7 +400,7 @@ type InodeIsDirTruncate struct{} // Truncate implements fs.InodeOperations.Truncate. func (InodeIsDirTruncate) Truncate(context.Context, *fs.Inode, int64) error { - return syserror.EISDIR + return linuxerr.EISDIR } // InodeNoopTruncate implements fs.InodeOperations.Truncate as a noop. @@ -425,7 +424,7 @@ type InodeNotOpenable struct{} // GetFile implements fs.InodeOperations.GetFile. func (InodeNotOpenable) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error) { - return nil, syserror.EIO + return nil, linuxerr.EIO } // InodeNotVirtual can be used by Inodes that are not virtual. @@ -529,5 +528,5 @@ type InodeIsDirAllocate struct{} // Allocate implements fs.InodeOperations.Allocate. func (InodeIsDirAllocate) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error { - return syserror.EISDIR + return linuxerr.EISDIR } diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index c08301d19..ee2f287d9 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -26,6 +26,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors", "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/hostarch", @@ -48,7 +49,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/unet", "//pkg/usermem", "//pkg/waiter", @@ -63,10 +63,10 @@ go_test( library = ":gofer", deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/p9", "//pkg/p9/p9test", "//pkg/sentry/contexttest", "//pkg/sentry/fs", - "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 73d80d9b5..62a517cd7 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" @@ -28,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -226,7 +226,7 @@ func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { if fs.IsDir(file.Dirent.Inode.StableAttr) { // Not all remote file systems enforce this so this client does. - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } var ( @@ -294,7 +294,7 @@ func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO if fs.IsDir(file.Dirent.Inode.StableAttr) { // Not all remote file systems enforce this so this client does. f.incrementReadCounters(start) - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) { diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go index 546ee7d04..4924debeb 100644 --- a/pkg/sentry/fs/gofer/gofer_test.go +++ b/pkg/sentry/fs/gofer/gofer_test.go @@ -19,8 +19,8 @@ import ( "testing" "time" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/p9/p9test" "gvisor.dev/gvisor/pkg/sentry/contexttest" @@ -97,7 +97,7 @@ func TestLookup(t *testing.T) { }, { name: "mock Walk fails (function fails)", - want: unix.ENOENT, + want: linuxerr.ENOENT, }, } @@ -123,7 +123,7 @@ func TestLookup(t *testing.T) { var newInodeOperations fs.InodeOperations if dirent != nil { if dirent.IsNegative() { - err = unix.ENOENT + err = linuxerr.ENOENT } else { newInodeOperations = dirent.Inode.InodeOperations } @@ -131,9 +131,11 @@ func TestLookup(t *testing.T) { // Check return values. if err != test.want { + t.Logf("err: %v %T", err, err) t.Errorf("Lookup got err %v, want %v", err, test.want) } if err == nil && newInodeOperations == nil { + t.Logf("err: %v %T", err, err) t.Errorf("Lookup got non-nil err and non-nil node, wanted at least one non-nil") } }) diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 9ff64a8b6..c3856094f 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + gErr "gvisor.dev/gvisor/pkg/errors" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" @@ -32,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/host" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // inodeOperations implements fs.InodeOperations. @@ -719,12 +719,12 @@ func (i *inodeOperations) configureMMap(file *fs.File, opts *memmap.MMapOpts) er } func init() { - syserror.AddErrorUnwrapper(func(err error) (unix.Errno, bool) { + linuxerr.AddErrorUnwrapper(func(err error) (*gErr.Error, bool) { if _, ok := err.(p9.ErrSocket); ok { // Treat as an I/O error. - return unix.EIO, true + return linuxerr.EIO, true } - return 0, false + return nil, false }) } diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 88d83060c..2f8769f1e 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/syserror" ) // maxFilenameLen is the maximum length of a filename. This is dictated by 9P's @@ -60,7 +59,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string if cp.cacheNegativeDirents() { return fs.NewNegativeDirent(name), nil } - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } i.readdirMu.Unlock() } @@ -74,7 +73,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string // is created over it. return fs.NewNegativeDirent(name), nil } - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } return nil, err } @@ -169,7 +168,7 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string hostFile.Close() } unopened.close(ctx) - return nil, syserror.EIO + return nil, linuxerr.EIO } qid := qids[0] diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 24fc6305c..921612e9c 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -52,7 +52,6 @@ go_library( "//pkg/sentry/uniqueid", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/unet", "//pkg/usermem", diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go index 77c08a7ce..1d0d95634 100644 --- a/pkg/sentry/fs/host/file.go +++ b/pkg/sentry/fs/host/file.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -201,7 +200,7 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I writer := fd.NewReadWriter(f.iops.fileState.FD()) n, err := src.CopyInTo(ctx, safemem.FromIOWriter{writer}) if isBlockError(err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } return n, err } @@ -232,7 +231,7 @@ func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO if n != 0 { err = nil } else { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } } return n, err diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 5f6af2067..92d58e3e9 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -220,7 +219,7 @@ func (i *inodeOperations) Release(context.Context) { // Lookup implements fs.InodeOperations.Lookup. func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } // Create implements fs.InodeOperations.Create. @@ -400,7 +399,7 @@ func (i *inodeOperations) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error // StatFS implements fs.InodeOperations.StatFS. func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) { - return fs.Info{}, syserror.ENOSYS + return fs.Info{}, linuxerr.ENOSYS } // AddLink implements fs.InodeOperations.AddLink. diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 6f38b25c3..4e561c5ed 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -327,7 +326,7 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e // If the signal is SIGTTIN, then we are attempting to read // from the TTY. Don't send the signal and return EIO. if sig == linux.SIGTTIN { - return syserror.EIO + return linuxerr.EIO } // Otherwise, we are writing or changing terminal state. This is allowed. @@ -336,7 +335,7 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e // If the process group is an orphan, return EIO. if pg.IsOrphan() { - return syserror.EIO + return linuxerr.EIO } // Otherwise, send the signal to the process group and return ERESTARTSYS. @@ -349,7 +348,7 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e // // Linux ignores the result of kill_pgrp(). _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) - return syserror.ERESTARTSYS + return linuxerr.ERESTARTSYS } // LINT.ThenChange(../../fsimpl/host/tty.go) diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go index e7db79189..f2a33cc14 100644 --- a/pkg/sentry/fs/host/util.go +++ b/pkg/sentry/fs/host/util.go @@ -96,7 +96,7 @@ type dirInfo struct { // LINT.IfChange // isBlockError unwraps os errors and checks if they are caused by EAGAIN or -// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock. +// EWOULDBLOCK. This is so they can be transformed into linuxerr.ErrWouldBlock. func isBlockError(err error) bool { if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) { return true diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index ec204e5cf..2c6b9e9db 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // Inode is a file system object that can be simultaneously referenced by different @@ -357,7 +356,7 @@ func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error // Truncate calls i.InodeOperations.Truncate with i as the Inode. func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error { if IsDir(i.StableAttr) { - return syserror.EISDIR + return linuxerr.EISDIR } if i.overlay != nil { diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go index 98e9fb2b1..0f8022906 100644 --- a/pkg/sentry/fs/inode_operations.go +++ b/pkg/sentry/fs/inode_operations.go @@ -66,7 +66,7 @@ type InodeOperations interface { // // * A nil Dirent and a non-nil error. If the reason that Lookup failed // was because the name does not exist under Inode, then must return - // syserror.ENOENT. + // linuxerr.ENOENT. // // * If name does not exist under dir and the file system wishes this // fact to be cached, a non-nil Dirent containing a nil Inode and a diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index c47b9ce58..21ad7fa69 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/syserror" ) func overlayHasWhiteout(ctx context.Context, parent *Inode, name string) bool { @@ -103,7 +102,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // Upper fs is not OK with a negative Dirent // being cached in the Dirent tree, so don't // return one. - return nil, false, syserror.ENOENT + return nil, false, linuxerr.ENOENT } entry, err := newOverlayEntry(ctx, upperInode, nil, false) if err != nil { @@ -165,7 +164,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name if negativeUpperChild { return NewNegativeDirent(name), false, nil } - return nil, false, syserror.ENOENT + return nil, false, linuxerr.ENOENT } // Did we find a lower Inode? Remember this because we may decide we don't diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index ee28b0f99..51cd6cd37 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -141,7 +140,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i if i.events.Empty() { // Nothing to read yet, tell caller to block. - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } var writeLen int64 @@ -179,7 +178,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i // WriteTo implements FileOperations.WriteTo. func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) { - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS } // Fsync implements FileOperations.Fsync. @@ -189,7 +188,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error { // ReadFrom implements FileOperations.ReadFrom. func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) { - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS } // Flush implements FileOperations.Flush. diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index e6d74b949..bc75ae505 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -50,7 +50,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sync", - "//pkg/syserror", "//pkg/tcpip/header", "//pkg/tcpip/network/ipv4", "//pkg/usermem", diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go index 379429ab2..75dc5d204 100644 --- a/pkg/sentry/fs/proc/exec_args.go +++ b/pkg/sentry/fs/proc/exec_args.go @@ -107,7 +107,7 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen return 0, linuxerr.EINVAL } - m, err := getTaskMM(f.t) + m, err := getTaskMMIncRef(f.t) if err != nil { return 0, err } diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go index e90da225a..e68bb46c0 100644 --- a/pkg/sentry/fs/proc/fds.go +++ b/pkg/sentry/fs/proc/fds.go @@ -20,12 +20,12 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // LINT.IfChange @@ -37,7 +37,7 @@ func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDF n, err := strconv.ParseUint(p, 10, 64) if err != nil { // Not found. - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } var file *fs.File @@ -48,7 +48,7 @@ func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDF } }) if file == nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } return toInode(file, fdFlags), nil } diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 546b57287..b9629c598 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package proc implements a partial in-memory file system for profs. +// Package proc implements a partial in-memory file system for procfs. package proc import ( @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // LINT.IfChange @@ -125,7 +124,7 @@ func (s *self) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { if t := kernel.TaskFromContext(ctx); t != nil { tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) if tgid == 0 { - return "", syserror.ENOENT + return "", linuxerr.ENOENT } return strconv.FormatUint(uint64(tgid), 10), nil } @@ -149,7 +148,7 @@ func (s *threadSelf) Readlink(ctx context.Context, inode *fs.Inode) (string, err tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) tid := s.pidns.IDOfTask(t) if tid == 0 || tgid == 0 { - return "", syserror.ENOENT + return "", linuxerr.ENOENT } return fmt.Sprintf("%d/task/%d", tgid, tid), nil } diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index 085aa6d61..443b9a94c 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -109,6 +109,9 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode "shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))), "shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))), "shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))), + "msgmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNI, 10))), + "msgmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMAX, 10))), + "msgmnb": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNB, 10))), } d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index edd62b857..03f2a882d 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -35,17 +35,30 @@ import ( "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) // LINT.IfChange -// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's -// users count is incremented, and must be decremented by the caller when it is -// no longer in use. -func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) { +// getTaskMM gets the kernel task's MemoryManager. No additional reference is +// taken on mm here. This is safe because MemoryManager.destroy is required to +// leave the MemoryManager in a state where it's still usable as a +// DynamicBytesSource. +func getTaskMM(t *kernel.Task) *mm.MemoryManager { + var tmm *mm.MemoryManager + t.WithMuLocked(func(t *kernel.Task) { + if mm := t.MemoryManager(); mm != nil { + tmm = mm + } + }) + return tmm +} + +// getTaskMMIncRef returns t's MemoryManager. If getTaskMMIncRef succeeds, the +// MemoryManager's users count is incremented, and must be decremented by the +// caller when it is no longer in use. +func getTaskMMIncRef(t *kernel.Task) (*mm.MemoryManager, error) { if t.ExitState() == kernel.TaskExitDead { return nil, linuxerr.ESRCH } @@ -182,7 +195,7 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry tasks := f.t.ThreadGroup().MemberIDs(f.pidns) if len(tasks) == 0 { - return offset, syserror.ENOENT + return offset, linuxerr.ENOENT } if offset == 0 { @@ -234,15 +247,15 @@ var _ fs.FileOperations = (*subtasksFile)(nil) func (s *subtasks) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) { tid, err := strconv.ParseUint(p, 10, 32) if err != nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } task := s.p.pidns.TaskWithID(kernel.ThreadID(tid)) if task == nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } if task.ThreadGroup() != s.t.ThreadGroup() { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } td := s.p.newTaskDir(ctx, task, dir.MountSource, false) @@ -270,21 +283,18 @@ func (e *exe) executable() (file fsbridge.File, err error) { if err := checkTaskState(e.t); err != nil { return nil, err } - e.t.WithMuLocked(func(t *kernel.Task) { - mm := t.MemoryManager() - if mm == nil { - err = linuxerr.EACCES - return - } + mm := getTaskMM(e.t) + if mm == nil { + return nil, linuxerr.EACCES + } - // The MemoryManager may be destroyed, in which case - // MemoryManager.destroy will simply set the executable to nil - // (with locks held). - file = mm.Executable() - if file == nil { - err = linuxerr.ESRCH - } - }) + // The MemoryManager may be destroyed, in which case + // MemoryManager.destroy will simply set the executable to nil + // (with locks held). + file = mm.Executable() + if file == nil { + err = linuxerr.ESRCH + } return } @@ -464,7 +474,7 @@ func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen if dst.NumBytes() == 0 { return 0, nil } - mm, err := getTaskMM(m.t) + mm, err := getTaskMMIncRef(m.t) if err != nil { return 0, nil } @@ -479,7 +489,7 @@ func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen return int64(n), nil } if readErr != nil { - return 0, syserror.EIO + return 0, linuxerr.EIO } return 0, nil } @@ -495,22 +505,9 @@ func newMaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inod return newProcInode(ctx, seqfile.NewSeqFile(ctx, &mapsData{t}), msrc, fs.SpecialFile, t) } -func (md *mapsData) mm() *mm.MemoryManager { - var tmm *mm.MemoryManager - md.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - // No additional reference is taken on mm here. This is safe - // because MemoryManager.destroy is required to leave the - // MemoryManager in a state where it's still usable as a SeqSource. - tmm = mm - } - }) - return tmm -} - // NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. func (md *mapsData) NeedsUpdate(generation int64) bool { - if mm := md.mm(); mm != nil { + if mm := getTaskMM(md.t); mm != nil { return mm.NeedsUpdate(generation) } return true @@ -518,7 +515,7 @@ func (md *mapsData) NeedsUpdate(generation int64) bool { // ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. func (md *mapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { - if mm := md.mm(); mm != nil { + if mm := getTaskMM(md.t); mm != nil { return mm.ReadMapsSeqFileData(ctx, h) } return []seqfile.SeqData{}, 0 @@ -535,22 +532,9 @@ func newSmaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Ino return newProcInode(ctx, seqfile.NewSeqFile(ctx, &smapsData{t}), msrc, fs.SpecialFile, t) } -func (sd *smapsData) mm() *mm.MemoryManager { - var tmm *mm.MemoryManager - sd.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - // No additional reference is taken on mm here. This is safe - // because MemoryManager.destroy is required to leave the - // MemoryManager in a state where it's still usable as a SeqSource. - tmm = mm - } - }) - return tmm -} - // NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. func (sd *smapsData) NeedsUpdate(generation int64) bool { - if mm := sd.mm(); mm != nil { + if mm := getTaskMM(sd.t); mm != nil { return mm.NeedsUpdate(generation) } return true @@ -558,7 +542,7 @@ func (sd *smapsData) NeedsUpdate(generation int64) bool { // ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. func (sd *smapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { - if mm := sd.mm(); mm != nil { + if mm := getTaskMM(sd.t); mm != nil { return mm.ReadSmapsSeqFileData(ctx, h) } return []seqfile.SeqData{}, 0 @@ -628,12 +612,10 @@ func (s *taskStatData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) fmt.Fprintf(&buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime()))) var vss, rss uint64 - s.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) + if mm := getTaskMM(s.t); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } fmt.Fprintf(&buf, "%d %d ", vss, rss/hostarch.PageSize) // rsslim. @@ -678,12 +660,10 @@ func (s *statmData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([ } var vss, rss uint64 - s.t.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) + if mm := getTaskMM(s.t); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } var buf bytes.Buffer fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize) @@ -735,12 +715,13 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( if fdTable := t.FDTable(); fdTable != nil { fds = fdTable.CurrentMaxFDs() } - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - data = mm.VirtualDataSize() - } }) + + if mm := getTaskMM(s.t); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + data = mm.VirtualDataSize() + } fmt.Fprintf(&buf, "FDSize:\t%d\n", fds) fmt.Fprintf(&buf, "VmSize:\t%d kB\n", vss>>10) fmt.Fprintf(&buf, "VmRSS:\t%d kB\n", rss>>10) @@ -926,7 +907,7 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc return 0, linuxerr.EINVAL } - m, err := getTaskMM(f.t) + m, err := getTaskMMIncRef(f.t) if err != nil { return 0, err } diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index b46567cf8..bfff010c5 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -21,7 +21,6 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/socket/unix/transport", "//pkg/sync", - "//pkg/syserror", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index 33023af77..b1fadee7a 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // CreateOps represents operations to create different file types. @@ -284,9 +283,9 @@ func (d *Dir) walkLocked(ctx context.Context, p string) (*fs.Inode, error) { return inode, nil } - // fs.InodeOperations.Lookup returns syserror.ENOENT if p + // fs.InodeOperations.Lookup returns linuxerr.ENOENT if p // does not exist. - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } // createInodeOperationsCommon creates a new child node at this dir by calling diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index fff4befb2..266140f6f 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/syserror" ) // Splice moves data to this file, directly from another. @@ -55,26 +54,26 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, case dst.UniqueID < src.UniqueID: // Acquire dst first. if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } if !src.mu.Lock(ctx) { dst.mu.Unlock() - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } case dst.UniqueID > src.UniqueID: // Acquire src first. if !src.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } if !dst.mu.Lock(ctx) { src.mu.Unlock() - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } case dst.UniqueID == src.UniqueID: // Acquire only one lock; it's the same file. This is a // bit of a edge case, but presumably it's possible. if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } srcLock = false // Only need one unlock. } @@ -84,13 +83,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, case dstLock: // Acquire only dst. if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } opts.DstStart = dst.offset // Safe: locked. case srcLock: // Acquire only src. if !src.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted + return 0, linuxerr.ErrInterrupted } opts.SrcStart = src.offset // Safe: locked. } @@ -108,7 +107,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, limit, ok := dst.checkLimit(ctx, opts.DstStart) switch { case ok && limit == 0: - err = syserror.ErrExceedsFileSizeLimit + err = linuxerr.ErrExceedsFileSizeLimit case ok && limit < opts.Length: opts.Length = limit // Cap the write. } diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index 0148b33cf..e61115932 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -14,7 +14,6 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel/time", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index 093a14c1f..1c8518d71 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -134,7 +133,7 @@ func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.I } return sizeofUint64, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // Write implements fs.FileOperations.Write. diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5933cb67b..9e9dc06f3 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 3242dcb6a..5716e2ee9 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -155,12 +154,12 @@ func (d *dirInodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name str n, err := strconv.ParseUint(name, 10, 32) if err != nil { // Not found. - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } s, ok := d.replicas[uint32(n)] if !ok { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } s.IncRef() @@ -235,7 +234,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e n := d.next if n == math.MaxUint32 { - return nil, syserror.ENOMEM + return nil, linuxerr.ENOMEM } if _, ok := d.replicas[n]; ok { @@ -335,10 +334,10 @@ func (df *dirFileOperations) Readdir(ctx context.Context, file *fs.File, seriali // Read implements FileOperations.Read func (df *dirFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // Write implements FileOperations.Write. func (df *dirFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 3ba02c218..f9fca6d8e 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -20,10 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -193,7 +193,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque } return n, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) { @@ -207,7 +207,7 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ l.replicaWaiter.Notify(waiter.ReadableEvents) return n, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error { @@ -228,7 +228,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ } return n, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) { @@ -242,7 +242,7 @@ func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSeq l.masterWaiter.Notify(waiter.ReadableEvents) return n, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // transformer is a helper interface to make it easier to stateify queue. diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index 11d6c15d0..25d3c887e 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -17,12 +17,12 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,7 +110,7 @@ func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipl defer q.mu.Unlock() if !q.readable { - return 0, false, syserror.ErrWouldBlock + return 0, false, linuxerr.ErrWouldBlock } if dst.NumBytes() > canonMaxBytes { @@ -155,7 +155,7 @@ func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscip room := waitBufMaxBytes - q.waitBufLen // If out of room, return EAGAIN. if room == 0 && copyLen > 0 { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // Cap the size of the wait buffer. if copyLen > room { diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD index 4acc73ee0..23b5508fd 100644 --- a/pkg/sentry/fs/user/BUILD +++ b/pkg/sentry/fs/user/BUILD @@ -19,7 +19,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go index f6eaab2bd..67a9adfd7 100644 --- a/pkg/sentry/fs/user/path.go +++ b/pkg/sentry/fs/user/path.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // ResolveExecutablePath resolves the given executable name given the working @@ -81,7 +80,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s root := fs.RootFromContext(ctx) if root == nil { // Caller has no root. Don't bother traversing anything. - return "", syserror.ENOENT + return "", linuxerr.ENOENT } defer root.DecRef(ctx) for _, p := range paths { @@ -117,7 +116,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s } // Couldn't find it. - return "", syserror.ENOENT + return "", linuxerr.ENOENT } func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) { @@ -156,7 +155,7 @@ func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNam } // Couldn't find it. - return "", syserror.ENOENT + return "", linuxerr.ENOENT } // getPath returns the PATH as a slice of strings given the environment diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD index 4c9c5b344..e5fdcc776 100644 --- a/pkg/sentry/fsimpl/cgroupfs/BUILD +++ b/pkg/sentry/fsimpl/cgroupfs/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/context", "//pkg/coverage", "//pkg/errors/linuxerr", + "//pkg/fspath", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", @@ -43,7 +44,6 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go index 4290ffe0d..71bb0a9c8 100644 --- a/pkg/sentry/fsimpl/cgroupfs/base.go +++ b/pkg/sentry/fsimpl/cgroupfs/base.go @@ -88,7 +88,6 @@ type controller interface { // +stateify savable type cgroupInode struct { dir - fs *filesystem // ts is the list of tasks in this cgroup. The kernel is responsible for // removing tasks from this list before they're destroyed, so any tasks on @@ -102,9 +101,10 @@ var _ kernel.CgroupImpl = (*cgroupInode)(nil) func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode { c := &cgroupInode{ - fs: fs, - ts: make(map[*kernel.Task]struct{}), + dir: dir{fs: fs}, + ts: make(map[*kernel.Task]struct{}), } + c.dir.cgi = c contents := make(map[string]kernfs.Inode) contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c}) @@ -115,8 +115,7 @@ func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credential } c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555)) - c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - c.dir.InitRefs() + c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents)) atomic.AddUint64(&fs.numCgroups, 1) diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go index 22c8b7fda..edc3b50b9 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -32,7 +32,8 @@ // controllers associated with them. // // Since cgroupfs doesn't allow hardlinks, there is a unique mapping between -// cgroupfs dentries and inodes. +// cgroupfs dentries and inodes. Thus, cgroupfs inodes don't need to be ref +// counted and exist until they're unlinked once or the FS is destroyed. // // # Synchronization // @@ -48,10 +49,11 @@ // Lock order: // // kernel.CgroupRegistry.mu -// cgroupfs.filesystem.mu -// kernel.TaskSet.mu -// kernel.Task.mu -// cgroupfs.filesystem.tasksMu. +// kernfs.filesystem.mu +// kernel.TaskSet.mu +// kernel.Task.mu +// cgroupfs.filesystem.tasksMu. +// cgroupfs.dir.OrderedChildren.mu package cgroupfs import ( @@ -63,6 +65,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -108,6 +111,7 @@ type FilesystemType struct{} // +stateify savable type InternalData struct { DefaultControlValues map[string]int64 + InitialCgroupPath string } // filesystem implements vfs.FilesystemImpl and kernel.cgroupFS. @@ -134,6 +138,11 @@ type filesystem struct { numCgroups uint64 // Protected by atomic ops. root *kernfs.Dentry + // effectiveRoot is the initial cgroup new tasks are created in. Unless + // overwritten by internal mount options, root == effectiveRoot. If + // effectiveRoot != root, an extra reference is held on effectiveRoot for + // the lifetime of the filesystem. + effectiveRoot *kernfs.Dentry // tasksMu serializes task membership changes across all cgroups within a // filesystem. @@ -229,6 +238,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs := vfsfs.Impl().(*filesystem) ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID) fs.root.IncRef() + if fs.effectiveRoot != fs.root { + fs.effectiveRoot.IncRef() + } return vfsfs, fs.root.VFSDentry(), nil } @@ -245,8 +257,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt var defaults map[string]int64 if opts.InternalData != nil { - ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults) defaults = opts.InternalData.(*InternalData).DefaultControlValues + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults) } for _, ty := range wantControllers { @@ -286,6 +298,14 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt var rootD kernfs.Dentry rootD.InitRoot(&fs.Filesystem, root) fs.root = &rootD + fs.effectiveRoot = fs.root + + if err := fs.prepareInitialCgroup(ctx, vfsObj, opts); err != nil { + ctx.Warningf("cgroupfs.FilesystemType.GetFilesystem: failed to prepare initial cgroup: %v", err) + rootD.DecRef(ctx) + fs.VFSFilesystem().DecRef(ctx) + return nil, nil, err + } // Register controllers. The registry may be modified concurrently, so if we // get an error, we raced with someone else who registered the same @@ -303,10 +323,47 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return fs.VFSFilesystem(), rootD.VFSDentry(), nil } +// prepareInitialCgroup creates the initial cgroup according to opts. An initial +// cgroup is optional, and if not specified, this function is a no-op. +func (fs *filesystem) prepareInitialCgroup(ctx context.Context, vfsObj *vfs.VirtualFilesystem, opts vfs.GetFilesystemOptions) error { + if opts.InternalData == nil { + return nil + } + initPathStr := opts.InternalData.(*InternalData).InitialCgroupPath + if initPathStr == "" { + return nil + } + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: initial cgroup path: %v", initPathStr) + initPath := fspath.Parse(initPathStr) + if !initPath.Absolute || !initPath.HasComponents() { + ctx.Warningf("cgroupfs.FilesystemType.GetFilesystem: initial cgroup path invalid: %+v", initPath) + return linuxerr.EINVAL + } + + // Have initial cgroup target, create the tree. + cgDir := fs.root.Inode().(*cgroupInode) + for pit := initPath.Begin; pit.Ok(); pit = pit.Next() { + cgDirI, err := cgDir.NewDir(ctx, pit.String(), vfs.MkdirOptions{}) + if err != nil { + return err + } + cgDir = cgDirI.(*cgroupInode) + } + + // Walk to target dentry. + initDentry, err := fs.root.WalkDentryTree(ctx, vfsObj, initPath) + if err != nil { + ctx.Warningf("cgroupfs.FilesystemType.GetFilesystem: initial cgroup dentry not found: %v", err) + return linuxerr.ENOENT + } + fs.effectiveRoot = initDentry // Reference from WalkDentryTree transferred here. + return nil +} + func (fs *filesystem) rootCgroup() kernel.Cgroup { return kernel.Cgroup{ - Dentry: fs.root, - CgroupImpl: fs.root.Inode().(kernel.CgroupImpl), + Dentry: fs.effectiveRoot, + CgroupImpl: fs.effectiveRoot.Inode().(kernel.CgroupImpl), } } @@ -320,6 +377,10 @@ func (fs *filesystem) Release(ctx context.Context) { r.Unregister(fs.hierarchyID) } + if fs.root != fs.effectiveRoot { + fs.effectiveRoot.DecRef(ctx) + } + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) fs.Filesystem.Release(ctx) } @@ -346,15 +407,18 @@ func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error // // +stateify savable type dir struct { - dirRefs + kernfs.InodeNoopRefCount kernfs.InodeAlwaysValid kernfs.InodeAttrs kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir. + kernfs.InodeDirectoryNoNewChildren kernfs.OrderedChildren implStatFS locks vfs.FileLocks + + fs *filesystem // Immutable. + cgi *cgroupInode // Immutable. } // Keep implements kernfs.Inode.Keep. @@ -378,9 +442,100 @@ func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry return fd.VFSFileDescription(), nil } -// DecRef implements kernfs.Inode.DecRef. -func (d *dir) DecRef(ctx context.Context) { - d.dirRefs.DecRef(func() { d.Destroy(ctx) }) +// NewDir implements kernfs.Inode.NewDir. +func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) { + // "Do not accept '\n' to prevent making /proc/<pid>/cgroup unparsable." + // -- Linux, kernel/cgroup.c:cgroup_mkdir(). + if strings.Contains(name, "\n") { + return nil, linuxerr.EINVAL + } + return d.OrderedChildren.Inserter(name, func() kernfs.Inode { + d.IncLinks(1) + return d.fs.newCgroupInode(ctx, auth.CredentialsFromContext(ctx)) + }) +} + +// Rename implements kernfs.Inode.Rename. Cgroupfs only allows renaming of +// cgroup directories, and the rename may only change the name within the same +// parent. See linux, kernel/cgroup.c:cgroup_rename(). +func (d *dir) Rename(ctx context.Context, oldname, newname string, child, dst kernfs.Inode) error { + if _, ok := child.(*cgroupInode); !ok { + // Not a cgroup directory. Control files are backed by different types. + return linuxerr.ENOTDIR + } + + dstCGInode, ok := dst.(*cgroupInode) + if !ok { + // Not a cgroup inode, so definitely can't be *this* inode. + return linuxerr.EIO + } + // Note: We're intentionally comparing addresses, since two different dirs + // could plausibly be identical in memory, but would occupy different + // locations in memory. + if d != &dstCGInode.dir { + // Destination dir is a different cgroup inode. Cross directory renames + // aren't allowed. + return linuxerr.EIO + } + + // Rename moves oldname to newname within d. Proceed. + return d.OrderedChildren.Rename(ctx, oldname, newname, child, dst) +} + +// Unlink implements kernfs.Inode.Unlink. Cgroupfs disallows unlink, as the only +// files in the filesystem are control files, which can't be deleted. +func (d *dir) Unlink(ctx context.Context, name string, child kernfs.Inode) error { + return linuxerr.EPERM +} + +// hasChildrenLocked returns whether the cgroup dir contains any objects that +// prevent it from being deleted. +func (d *dir) hasChildrenLocked() bool { + // Subdirs take a link on the parent, so checks if there are any direct + // children cgroups. Exclude the dir's self link and the link from ".". + if d.InodeAttrs.Links()-2 > 0 { + return true + } + return len(d.cgi.ts) > 0 +} + +// HasChildren implements kernfs.Inode.HasChildren. +// +// The empty check for a cgroupfs directory is unlike a regular directory since +// a cgroupfs directory will always have control files. A cgroupfs directory can +// be deleted if cgroup contains no tasks and has no sub-cgroups. +func (d *dir) HasChildren() bool { + d.fs.tasksMu.RLock() + defer d.fs.tasksMu.RUnlock() + return d.hasChildrenLocked() +} + +// RmDir implements kernfs.Inode.RmDir. +func (d *dir) RmDir(ctx context.Context, name string, child kernfs.Inode) error { + // Unlike a normal directory, we need to recheck if d is empty again, since + // vfs/kernfs can't stop tasks from entering or leaving the cgroup. + d.fs.tasksMu.RLock() + defer d.fs.tasksMu.RUnlock() + + cgi, ok := child.(*cgroupInode) + if !ok { + return linuxerr.ENOTDIR + } + if cgi.dir.hasChildrenLocked() { + return linuxerr.ENOTEMPTY + } + + // Disallow deletion of the effective root cgroup. + if cgi == d.fs.effectiveRoot.Inode().(*cgroupInode) { + ctx.Warningf("Cannot delete initial cgroup for new tasks %q", d.fs.effectiveRoot.FSLocalPath()) + return linuxerr.EBUSY + } + + err := d.OrderedChildren.RmDir(ctx, name, child) + if err == nil { + d.InodeAttrs.DecLinks() + } + return err } // controllerFile represents a generic control file that appears within a cgroup diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go index 485c98376..d880c9bc4 100644 --- a/pkg/sentry/fsimpl/cgroupfs/memory.go +++ b/pkg/sentry/fsimpl/cgroupfs/memory.go @@ -31,22 +31,34 @@ import ( type memoryController struct { controllerCommon - limitBytes int64 + limitBytes int64 + softLimitBytes int64 + moveChargeAtImmigrate int64 } var _ controller = (*memoryController)(nil) func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController { c := &memoryController{ - // Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which - // is ~ 2**63 on a 64-bit system. So essentially, inifinity. The exact - // value isn't very important. - limitBytes: math.MaxInt64, + // Linux sets these limits to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, + // which is ~ 2**63 on a 64-bit system. So essentially, inifinity. The + // exact value isn't very important. + + limitBytes: math.MaxInt64, + softLimitBytes: math.MaxInt64, } - if val, ok := defaults["memory.limit_in_bytes"]; ok { - c.limitBytes = val - delete(defaults, "memory.limit_in_bytes") + + consumeDefault := func(name string, valPtr *int64) { + if val, ok := defaults[name]; ok { + *valPtr = val + delete(defaults, name) + } } + + consumeDefault("memory.limit_in_bytes", &c.limitBytes) + consumeDefault("memory.soft_limit_in_bytes", &c.softLimitBytes) + consumeDefault("memory.move_charge_at_immigrate", &c.moveChargeAtImmigrate) + c.controllerCommon.init(controllerMemory, fs) return c } @@ -55,6 +67,8 @@ func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryContr func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{}) contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes)) + contents["memory.soft_limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.softLimitBytes)) + contents["memory.move_charge_at_immigrate"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.moveChargeAtImmigrate)) } // +stateify savable diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index f981ff296..e0b879339 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -45,7 +45,6 @@ go_library( "//pkg/sentry/unimpl", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index 7a488e9fd..e711debcb 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // Name is the filesystem name. @@ -180,7 +179,7 @@ func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credential i.mu.Lock() defer i.mu.Unlock() if i.nextIdx == math.MaxUint32 { - return nil, syserror.ENOMEM + return nil, linuxerr.ENOMEM } idx := i.nextIdx i.nextIdx++ @@ -241,7 +240,7 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro // Not a static entry. idx, err := strconv.ParseUint(name, 10, 32) if err != nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } i.mu.Lock() defer i.mu.Unlock() @@ -250,7 +249,7 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro return ri, nil } - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } // IterDirents implements kernfs.Inode.IterDirents. diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index 9cb21e83b..609623f9f 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -20,10 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -203,7 +203,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque } else if notifyEcho { l.masterWaiter.Notify(waiter.ReadableEvents) } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) { @@ -220,7 +220,7 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ l.replicaWaiter.Notify(waiter.ReadableEvents) return n, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { @@ -242,7 +242,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ } return n, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) { @@ -257,7 +257,7 @@ func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSeq l.masterWaiter.Notify(waiter.ReadableEvents) return n, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // transformer is a helper interface to make it easier to stateify queue. diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go index ff1d89955..85aeefa43 100644 --- a/pkg/sentry/fsimpl/devpts/queue.go +++ b/pkg/sentry/fsimpl/devpts/queue.go @@ -17,12 +17,12 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,7 +110,7 @@ func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipl defer q.mu.Unlock() if !q.readable { - return 0, false, false, syserror.ErrWouldBlock + return 0, false, false, linuxerr.ErrWouldBlock } if dst.NumBytes() > canonMaxBytes { @@ -156,7 +156,7 @@ func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscip room := waitBufMaxBytes - q.waitBufLen // If out of room, return EAGAIN. if room == 0 && copyLen > 0 { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // Cap the size of the wait buffer. if copyLen > room { diff --git a/pkg/sentry/fsimpl/eventfd/BUILD b/pkg/sentry/fsimpl/eventfd/BUILD index c09fdc7f9..1cb049a29 100644 --- a/pkg/sentry/fsimpl/eventfd/BUILD +++ b/pkg/sentry/fsimpl/eventfd/BUILD @@ -9,11 +9,11 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fdnotifier", "//pkg/hostarch", "//pkg/log", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 4f79cfcb7..af5ba5131 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -22,11 +22,11 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -149,7 +149,7 @@ func (efd *EventFileDescription) hostReadLocked(ctx context.Context, dst usermem var buf [8]byte if _, err := unix.Read(efd.hostfd, buf[:]); err != nil { if err == unix.EWOULDBLOCK { - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } return err } @@ -167,7 +167,7 @@ func (efd *EventFileDescription) read(ctx context.Context, dst usermem.IOSequenc // We can't complete the read if the value is currently zero. if efd.val == 0 { efd.mu.Unlock() - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } // Update the value based on the mode the event is operating in. @@ -200,7 +200,7 @@ func (efd *EventFileDescription) hostWriteLocked(val uint64) error { hostarch.ByteOrder.PutUint64(buf[:], val) _, err := unix.Write(efd.hostfd, buf[:]) if err == unix.EWOULDBLOCK { - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } return err } @@ -232,7 +232,7 @@ func (efd *EventFileDescription) Signal(val uint64) error { // uint64 minus 1. if val > math.MaxUint64-1-efd.val { efd.mu.Unlock() - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } efd.val += val diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/pkg/sentry/fsimpl/ext/BUILD diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 871df5984..05c4fbeb2 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -59,7 +59,6 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", @@ -84,7 +83,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index dab1e779d..0f855ac59 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -38,7 +37,7 @@ type fuseDevice struct{} // Open implements vfs.Device.Open. func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { if !kernel.FUSEEnabled { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } var fd DeviceFD @@ -126,7 +125,7 @@ func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset in return 0, linuxerr.EPERM } - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS } // Read implements vfs.FileDescriptionImpl.Read. @@ -192,7 +191,7 @@ func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts } if req == nil { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // We already checked the size: dst must be able to fit the whole request. @@ -205,7 +204,7 @@ func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts return 0, err } if n != len(req.data) { - return 0, syserror.EIO + return 0, linuxerr.EIO } if req.hdr.Opcode == linux.FUSE_WRITE { @@ -214,7 +213,7 @@ func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts return 0, err } if written != len(req.payload) { - return 0, syserror.EIO + return 0, linuxerr.EIO } n += int(written) } @@ -238,7 +237,7 @@ func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset i return 0, linuxerr.EPERM } - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS } // Write implements vfs.FileDescriptionImpl.Write. @@ -395,7 +394,7 @@ func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64 return 0, linuxerr.EPERM } - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS } // sendResponse sends a response to the waiting task (if any). diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 04250d796..8951b5ba8 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -20,11 +20,11 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -186,7 +186,7 @@ func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem. // "would block". n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{}) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go index fcc5d9a2a..9611edd5a 100644 --- a/pkg/sentry/fsimpl/fuse/directory.go +++ b/pkg/sentry/fsimpl/fuse/directory.go @@ -19,10 +19,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -32,27 +32,27 @@ type directoryFD struct { // Allocate implements directoryFD.Allocate. func (*directoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { - return syserror.EISDIR + return linuxerr.EISDIR } // PRead implements vfs.FileDescriptionImpl.PRead. func (*directoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // Read implements vfs.FileDescriptionImpl.Read. func (*directoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // PWrite implements vfs.FileDescriptionImpl.PWrite. func (*directoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // Write implements vfs.FileDescriptionImpl.Write. func (*directoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 172cbd88f..af16098d2 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -612,7 +611,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo return nil, err } if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { - return nil, syserror.EIO + return nil, linuxerr.EIO } child := i.fs.newInode(ctx, out.NodeID, out.Attr) return child, nil diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 35d0ab6f4..fe119aa43 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" ) // ReadInPages sends FUSE_READ requests for the size after round it up to @@ -221,7 +220,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, // Write more than requested? EIO. if out.Size > toWrite { - return 0, syserror.EIO + return 0, linuxerr.EIO } written += out.Size diff --git a/pkg/sentry/fsimpl/fuse/regular_file.go b/pkg/sentry/fsimpl/fuse/regular_file.go index 6c4de3507..38cde8208 100644 --- a/pkg/sentry/fsimpl/fuse/regular_file.go +++ b/pkg/sentry/fsimpl/fuse/regular_file.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -108,7 +107,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs return 0, err } if int64(cp) != toCopy { - return 0, syserror.EIO + return 0, linuxerr.EIO } copied += toCopy } @@ -205,7 +204,7 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off return 0, offset, err } if int64(cp) != srclen { - return 0, offset, syserror.EIO + return 0, offset, linuxerr.EIO } n, err := fd.inode().fs.Write(ctx, fd, uint64(offset), uint32(srclen), data) @@ -216,7 +215,7 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off if n == 0 { // We have checked srclen != 0 previously. // If err == nil, then it's a short write and we return EIO. - return 0, offset, syserror.EIO + return 0, offset, linuxerr.EIO } written = int64(n) diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 752060044..509dd0e1a 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -54,7 +54,10 @@ go_library( "//pkg/fdnotifier", "//pkg/fspath", "//pkg/hostarch", + "//pkg/lisafs", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/p9", "//pkg/refs", @@ -79,7 +82,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/unet", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 5c48a9fee..d99a6112c 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -222,47 +222,88 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { off := uint64(0) const count = 64 * 1024 // for consistency with the vfs1 client d.handleMu.RLock() - if d.readFile.isNil() { + if !d.isReadFileOk() { // This should not be possible because a readable handle should // have been opened when the calling directoryFD was opened. d.handleMu.RUnlock() panic("gofer.dentry.getDirents called without a readable handle") } + // shouldSeek0 indicates whether the server should SEEK to 0 before reading + // directory entries. + shouldSeek0 := true for { - p9ds, err := d.readFile.readdir(ctx, off, count) - if err != nil { - d.handleMu.RUnlock() - return nil, err - } - if len(p9ds) == 0 { - d.handleMu.RUnlock() - break - } - for _, p9d := range p9ds { - if p9d.Name == "." || p9d.Name == ".." { - continue + if d.fs.opts.lisaEnabled { + countLisa := int32(count) + if shouldSeek0 { + // See lisafs.Getdents64Req.Count. + countLisa = -countLisa + shouldSeek0 = false + } + lisafsDs, err := d.readFDLisa.Getdents64(ctx, countLisa) + if err != nil { + d.handleMu.RUnlock() + return nil, err + } + if len(lisafsDs) == 0 { + d.handleMu.RUnlock() + break + } + for i := range lisafsDs { + name := string(lisafsDs[i].Name) + if name == "." || name == ".." { + continue + } + dirent := vfs.Dirent{ + Name: name, + Ino: d.fs.inoFromKey(inoKey{ + ino: uint64(lisafsDs[i].Ino), + devMinor: uint32(lisafsDs[i].DevMinor), + devMajor: uint32(lisafsDs[i].DevMajor), + }), + NextOff: int64(len(dirents) + 1), + Type: uint8(lisafsDs[i].Type), + } + dirents = append(dirents, dirent) + if realChildren != nil { + realChildren[name] = struct{}{} + } } - dirent := vfs.Dirent{ - Name: p9d.Name, - Ino: d.fs.inoFromQIDPath(p9d.QID.Path), - NextOff: int64(len(dirents) + 1), + } else { + p9ds, err := d.readFile.readdir(ctx, off, count) + if err != nil { + d.handleMu.RUnlock() + return nil, err } - // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or - // DMSOCKET. - switch p9d.Type { - case p9.TypeSymlink: - dirent.Type = linux.DT_LNK - case p9.TypeDir: - dirent.Type = linux.DT_DIR - default: - dirent.Type = linux.DT_REG + if len(p9ds) == 0 { + d.handleMu.RUnlock() + break } - dirents = append(dirents, dirent) - if realChildren != nil { - realChildren[p9d.Name] = struct{}{} + for _, p9d := range p9ds { + if p9d.Name == "." || p9d.Name == ".." { + continue + } + dirent := vfs.Dirent{ + Name: p9d.Name, + Ino: d.fs.inoFromQIDPath(p9d.QID.Path), + NextOff: int64(len(dirents) + 1), + } + // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or + // DMSOCKET. + switch p9d.Type { + case p9.TypeSymlink: + dirent.Type = linux.DT_LNK + case p9.TypeDir: + dirent.Type = linux.DT_DIR + default: + dirent.Type = linux.DT_REG + } + dirents = append(dirents, dirent) + if realChildren != nil { + realChildren[p9d.Name] = struct{}{} + } } + off = p9ds[len(p9ds)-1].Offset } - off = p9ds[len(p9ds)-1].Offset } } // Emit entries for synthetic children. diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 05b776c2e..f7b3446d3 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -21,10 +21,12 @@ import ( "sync" "sync/atomic" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" "gvisor.dev/gvisor/pkg/sentry/fsmetric" @@ -33,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // Sync implements vfs.FilesystemImpl.Sync. @@ -54,9 +55,47 @@ func (fs *filesystem) Sync(ctx context.Context) error { // regardless. var retErr error + if fs.opts.lisaEnabled { + // Try accumulating all FDIDs to fsync and fsync then via one RPC as + // opposed to making an RPC per FDID. Passing a non-nil accFsyncFDIDs to + // dentry.syncCachedFile() and specialFileFD.sync() will cause them to not + // make an RPC, instead accumulate syncable FDIDs in the passed slice. + accFsyncFDIDs := make([]lisafs.FDID, 0, len(ds)+len(sffds)) + + // Sync syncable dentries. + for _, d := range ds { + if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil { + ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) + if retErr == nil { + retErr = err + } + } + } + + // Sync special files, which may be writable but do not use dentry shared + // handles (so they won't be synced by the above). + for _, sffd := range sffds { + if err := sffd.sync(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil { + ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) + if retErr == nil { + retErr = err + } + } + } + + if err := fs.clientLisa.SyncFDs(ctx, accFsyncFDIDs); err != nil { + ctx.Infof("gofer.filesystem.Sync: fs.fsyncMultipleFDLisa failed: %v", err) + if retErr == nil { + retErr = err + } + } + + return retErr + } + // Sync syncable dentries. for _, d := range ds { - if err := d.syncCachedFile(ctx, true /* forFilesystemSync */); err != nil { + if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil { ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) if retErr == nil { retErr = err @@ -67,7 +106,7 @@ func (fs *filesystem) Sync(ctx context.Context) error { // Sync special files, which may be writable but do not use dentry shared // handles (so they won't be synced by the above). for _, sffd := range sffds { - if err := sffd.sync(ctx, true /* forFilesystemSync */); err != nil { + if err := sffd.sync(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil { ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) if retErr == nil { retErr = err @@ -198,7 +237,13 @@ afterSymlink: rp.Advance() return d.parent, followedSymlink, nil } - child, err := fs.getChildLocked(ctx, d, name, ds) + var child *dentry + var err error + if fs.opts.lisaEnabled { + child, err = fs.getChildAndWalkPathLocked(ctx, d, rp, ds) + } else { + child, err = fs.getChildLocked(ctx, d, name, ds) + } if err != nil { return nil, false, err } @@ -220,6 +265,99 @@ afterSymlink: return child, followedSymlink, nil } +// Preconditions: +// * fs.opts.lisaEnabled. +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. +// * parent.isDir(). +// * parent and the dentry at name have been revalidated. +func (fs *filesystem) getChildAndWalkPathLocked(ctx context.Context, parent *dentry, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { + // Note that pit is a copy of the iterator that does not affect rp. + pit := rp.Pit() + first := pit.String() + if len(first) > maxFilenameLen { + return nil, linuxerr.ENAMETOOLONG + } + if child, ok := parent.children[first]; ok || parent.isSynthetic() { + if child == nil { + return nil, linuxerr.ENOENT + } + return child, nil + } + + // Walk as much of the path as possible in 1 RPC. + names := []string{first} + for pit = pit.Next(); pit.Ok(); pit = pit.Next() { + name := pit.String() + if name == "." { + continue + } + if name == ".." { + break + } + names = append(names, name) + } + status, inodes, err := parent.controlFDLisa.WalkMultiple(ctx, names) + if err != nil { + return nil, err + } + if len(inodes) == 0 { + parent.cacheNegativeLookupLocked(first) + return nil, linuxerr.ENOENT + } + + // Add the walked inodes into the dentry tree. + curParent := parent + curParentDirMuLock := func() { + if curParent != parent { + curParent.dirMu.Lock() + } + } + curParentDirMuUnlock := func() { + if curParent != parent { + curParent.dirMu.Unlock() // +checklocksforce: locked via curParentDirMuLock(). + } + } + var ret *dentry + var dentryCreationErr error + for i := range inodes { + if dentryCreationErr != nil { + fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD) + continue + } + + child, err := fs.newDentryLisa(ctx, &inodes[i]) + if err != nil { + fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD) + dentryCreationErr = err + continue + } + curParentDirMuLock() + curParent.cacheNewChildLocked(child, names[i]) + curParentDirMuUnlock() + // For now, child has 0 references, so our caller should call + // child.checkCachingLocked(). curParent gained a ref so we should also + // call curParent.checkCachingLocked() so it can be removed from the cache + // if needed. We only do that for the first iteration because all + // subsequent parents would have already been added to ds. + if i == 0 { + *ds = appendDentry(*ds, curParent) + } + *ds = appendDentry(*ds, child) + curParent = child + if i == 0 { + ret = child + } + } + + if status == lisafs.WalkComponentDoesNotExist && curParent.isDir() { + curParentDirMuLock() + curParent.cacheNegativeLookupLocked(names[len(inodes)]) + curParentDirMuUnlock() + } + return ret, dentryCreationErr +} + // getChildLocked returns a dentry representing the child of parent with the // given name. Returns ENOENT if the child doesn't exist. // @@ -228,32 +366,47 @@ afterSymlink: // * parent.dirMu must be locked. // * parent.isDir(). // * name is not "." or "..". -// * dentry at name has been revalidated +// * parent and the dentry at name have been revalidated. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if len(name) > maxFilenameLen { return nil, linuxerr.ENAMETOOLONG } if child, ok := parent.children[name]; ok || parent.isSynthetic() { if child == nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } return child, nil } - qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) - if err != nil { - if linuxerr.Equals(linuxerr.ENOENT, err) { - parent.cacheNegativeLookupLocked(name) + var child *dentry + if fs.opts.lisaEnabled { + childInode, err := parent.controlFDLisa.Walk(ctx, name) + if err != nil { + if linuxerr.Equals(linuxerr.ENOENT, err) { + parent.cacheNegativeLookupLocked(name) + } + return nil, err + } + // Create a new dentry representing the file. + child, err = fs.newDentryLisa(ctx, childInode) + if err != nil { + fs.clientLisa.CloseFDBatched(ctx, childInode.ControlFD) + return nil, err + } + } else { + qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) + if err != nil { + if linuxerr.Equals(linuxerr.ENOENT, err) { + parent.cacheNegativeLookupLocked(name) + } + return nil, err + } + // Create a new dentry representing the file. + child, err = fs.newDentry(ctx, file, qid, attrMask, &attr) + if err != nil { + file.close(ctx) + return nil, err } - return nil, err - } - - // Create a new dentry representing the file. - child, err := fs.newDentry(ctx, file, qid, attrMask, &attr) - if err != nil { - file.close(ctx) - delete(parent.children, name) - return nil, err } parent.cacheNewChildLocked(child, name) appendNewChildDentry(ds, parent, child) @@ -329,7 +482,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, // Preconditions: // * !rp.Done(). // * For the final path component in rp, !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) error, createInSyntheticDir func(parent *dentry, name string) error) error { +func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error), createInSyntheticDir func(parent *dentry, name string) error, updateChild func(child *dentry)) error { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) @@ -349,7 +502,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return linuxerr.EEXIST } if parent.isDeleted() { - return syserror.ENOENT + return linuxerr.ENOENT } if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, name, &ds); err != nil { return err @@ -395,7 +548,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return err } if !dir && rp.MustBeDir() { - return syserror.ENOENT + return linuxerr.ENOENT } if parent.isSynthetic() { if createInSyntheticDir == nil { @@ -416,9 +569,26 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir // No cached dentry exists; however, in InteropModeShared there might still be // an existing file at name. Just attempt the file creation RPC anyways. If a // file does exist, the RPC will fail with EEXIST like we would have. - if err := createInRemoteDir(parent, name, &ds); err != nil { + lisaInode, err := createInRemoteDir(parent, name, &ds) + if err != nil { return err } + // lisafs may aggresively cache newly created inodes. This has helped reduce + // Walk RPCs in practice. + if lisaInode != nil { + child, err := fs.newDentryLisa(ctx, lisaInode) + if err != nil { + fs.clientLisa.CloseFDBatched(ctx, lisaInode.ControlFD) + return err + } + parent.cacheNewChildLocked(child, name) + appendNewChildDentry(&ds, parent, child) + + // lisafs may update dentry properties upon successful creation. + if updateChild != nil { + updateChild(child) + } + } if fs.opts.interop != InteropModeShared { if child, ok := parent.children[name]; ok && child == nil { // Delete the now-stale negative dentry. @@ -463,7 +633,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b } } else { if name == "." || name == ".." { - return syserror.EISDIR + return linuxerr.EISDIR } } @@ -486,7 +656,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b child, ok = parent.children[name] if ok && child == nil { // Hit a negative cached entry, child doesn't exist. - return syserror.ENOENT + return linuxerr.ENOENT } } else { child, _, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) @@ -552,7 +722,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b // child must be a non-directory file. if child != nil && child.isDir() { vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above. - return syserror.EISDIR + return linuxerr.EISDIR } if rp.MustBeDir() { if child != nil { @@ -563,10 +733,14 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b } if parent.isSynthetic() { if child == nil { - return syserror.ENOENT + return linuxerr.ENOENT } } else if child == nil || !child.isSynthetic() { - err = parent.file.unlinkAt(ctx, name, flags) + if fs.opts.lisaEnabled { + err = parent.controlFDLisa.UnlinkAt(ctx, name, flags) + } else { + err = parent.file.unlinkAt(ctx, name, flags) + } if err != nil { if child != nil { vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above. @@ -659,40 +833,43 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, _ **[]*dentry) error { + err := fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, ds **[]*dentry) (*lisafs.Inode, error) { if rp.Mount() != vd.Mount() { - return linuxerr.EXDEV + return nil, linuxerr.EXDEV } d := vd.Dentry().Impl().(*dentry) if d.isDir() { - return linuxerr.EPERM + return nil, linuxerr.EPERM } gid := auth.KGID(atomic.LoadUint32(&d.gid)) uid := auth.KUID(atomic.LoadUint32(&d.uid)) mode := linux.FileMode(atomic.LoadUint32(&d.mode)) if err := vfs.MayLink(rp.Credentials(), mode, uid, gid); err != nil { - return err + return nil, err } if d.nlink == 0 { - return syserror.ENOENT + return nil, linuxerr.ENOENT } if d.nlink == math.MaxUint32 { - return linuxerr.EMLINK + return nil, linuxerr.EMLINK } - if err := parent.file.link(ctx, d.file, childName); err != nil { - return err + if fs.opts.lisaEnabled { + return parent.controlFDLisa.LinkAt(ctx, d.controlFDLisa.ID(), childName) } + return nil, parent.file.link(ctx, d.file, childName) + }, nil, nil) + if err == nil { // Success! - atomic.AddUint32(&d.nlink, 1) - return nil - }, nil) + vd.Dentry().Impl().(*dentry).incLinks() + } + return err } // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { creds := rp.Credentials() - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { + return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) { // If the parent is a setgid directory, use the parent's GID // rather than the caller's and enable setgid. kgid := creds.EffectiveKGID @@ -701,9 +878,18 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kgid = auth.KGID(atomic.LoadUint32(&parent.gid)) mode |= linux.S_ISGID } - if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil { + var ( + childDirInode *lisafs.Inode + err error + ) + if fs.opts.lisaEnabled { + childDirInode, err = parent.controlFDLisa.MkdirAt(ctx, name, mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid)) + } else { + _, err = parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)) + } + if err != nil { if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) { - return err + return nil, err } ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err) parent.createSyntheticChildLocked(&createSyntheticOpts{ @@ -717,7 +903,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if fs.opts.interop != InteropModeShared { parent.incLinks() } - return nil + return childDirInode, nil }, func(parent *dentry, name string) error { if !opts.ForSyntheticMountpoint { // Can't create non-synthetic files in synthetic directories. @@ -731,16 +917,26 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v }) parent.incLinks() return nil - }) + }, nil) } // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) { creds := rp.Credentials() - _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - if !linuxerr.Equals(linuxerr.EPERM, err) { - return err + var ( + childInode *lisafs.Inode + err error + ) + if fs.opts.lisaEnabled { + childInode, err = parent.controlFDLisa.MknodAt(ctx, name, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID), opts.DevMinor, opts.DevMajor) + } else { + _, err = parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) + } + if err == nil { + return childInode, nil + } else if !linuxerr.Equals(linuxerr.EPERM, err) { + return nil, err } // EPERM means that gofer does not allow creating a socket or pipe. Fallback @@ -751,10 +947,10 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v switch { case err == nil: // Step succeeded, another file exists. - return linuxerr.EEXIST + return nil, linuxerr.EEXIST case !linuxerr.Equals(linuxerr.ENOENT, err): // Unexpected error. - return err + return nil, err } switch opts.Mode.FileType() { @@ -767,7 +963,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v endpoint: opts.Endpoint, }) *ds = appendDentry(*ds, parent) - return nil + return nil, nil case linux.S_IFIFO: parent.createSyntheticChildLocked(&createSyntheticOpts{ name: name, @@ -777,11 +973,11 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) *ds = appendDentry(*ds, parent) - return nil + return nil, nil } // Retain error from gofer if synthetic file cannot be created internally. - return linuxerr.EPERM - }, nil) + return nil, linuxerr.EPERM + }, nil, nil) } // OpenAt implements vfs.FilesystemImpl.OpenAt. @@ -811,7 +1007,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if rp.Done() { // Reject attempts to open mount root directory with O_CREAT. if mayCreate && rp.MustBeDir() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if mustCreate { return nil, linuxerr.EEXIST @@ -841,7 +1037,7 @@ afterTrailingSymlink: } // Reject attempts to open directories with O_CREAT. if mayCreate && rp.MustBeDir() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, rp.Component(), &ds); err != nil { return nil, err @@ -922,11 +1118,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open case linux.S_IFDIR: // Can't open directories with O_CREAT. if opts.Flags&linux.O_CREAT != 0 { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } // Can't open directories writably. if ats&vfs.MayWrite != 0 { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if opts.Flags&linux.O_DIRECT != 0 { return nil, linuxerr.EINVAL @@ -987,6 +1183,23 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio if opts.Flags&linux.O_DIRECT != 0 { return nil, linuxerr.EINVAL } + if d.fs.opts.lisaEnabled { + // Note that special value of linux.SockType = 0 is interpreted by lisafs + // as "do not care about the socket type". Analogous to p9.AnonymousSocket. + sockFD, err := d.controlFDLisa.Connect(ctx, 0 /* sockType */) + if err != nil { + return nil, err + } + fd, err := host.NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), sockFD, &host.NewFDOptions{ + HaveFlags: true, + Flags: opts.Flags, + }) + if err != nil { + unix.Close(sockFD) + return nil, err + } + return fd, nil + } fdObj, err := d.file.connect(ctx, p9.AnonymousSocket) if err != nil { return nil, err @@ -999,6 +1212,7 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio fdObj.Close() return nil, err } + // Ownership has been transferred to fd. fdObj.Release() return fd, nil } @@ -1018,7 +1232,13 @@ func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs. // since closed its end. isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0 retry: - h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) + var h handle + var err error + if d.fs.opts.lisaEnabled { + h, err = openHandleLisa(ctx, d.controlFDLisa, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) + } else { + h, err = openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) + } if err != nil { if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && linuxerr.Equals(linuxerr.ENXIO, err) { // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails @@ -1054,7 +1274,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving return nil, err } if d.isDeleted() { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } mnt := rp.Mount() if err := mnt.CheckBeginWrite(); err != nil { @@ -1062,18 +1282,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } defer mnt.EndWrite() - // 9P2000.L's lcreate takes a fid representing the parent directory, and - // converts it into an open fid representing the created file, so we need - // to duplicate the directory fid first. - _, dirfile, err := d.file.walk(ctx, nil) - if err != nil { - return nil, err - } creds := rp.Credentials() name := rp.Component() - // We only want the access mode for creating the file. - createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask - // If the parent is a setgid directory, use the parent's GID rather // than the caller's. kgid := creds.EffectiveKGID @@ -1081,51 +1291,87 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving kgid = auth.KGID(atomic.LoadUint32(&d.gid)) } - fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)) - if err != nil { - dirfile.close(ctx) - return nil, err - } - // Then we need to walk to the file we just created to get a non-open fid - // representing it, and to get its metadata. This must use d.file since, as - // explained above, dirfile was invalidated by dirfile.Create(). - _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name) - if err != nil { - openFile.close(ctx) - if fdobj != nil { - fdobj.Close() + var child *dentry + var openP9File p9file + openLisaFD := lisafs.InvalidFDID + openHostFD := int32(-1) + if d.fs.opts.lisaEnabled { + ino, openFD, hostFD, err := d.controlFDLisa.OpenCreateAt(ctx, name, opts.Flags&linux.O_ACCMODE, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid)) + if err != nil { + return nil, err + } + openHostFD = int32(hostFD) + openLisaFD = openFD + + child, err = d.fs.newDentryLisa(ctx, &ino) + if err != nil { + d.fs.clientLisa.CloseFDBatched(ctx, ino.ControlFD) + d.fs.clientLisa.CloseFDBatched(ctx, openFD) + if hostFD >= 0 { + unix.Close(hostFD) + } + return nil, err + } + } else { + // 9P2000.L's lcreate takes a fid representing the parent directory, and + // converts it into an open fid representing the created file, so we need + // to duplicate the directory fid first. + _, dirfile, err := d.file.walk(ctx, nil) + if err != nil { + return nil, err + } + // We only want the access mode for creating the file. + createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask + + fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)) + if err != nil { + dirfile.close(ctx) + return nil, err + } + // Then we need to walk to the file we just created to get a non-open fid + // representing it, and to get its metadata. This must use d.file since, as + // explained above, dirfile was invalidated by dirfile.Create(). + _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name) + if err != nil { + openFile.close(ctx) + if fdobj != nil { + fdobj.Close() + } + return nil, err + } + + // Construct the new dentry. + child, err = d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr) + if err != nil { + nonOpenFile.close(ctx) + openFile.close(ctx) + if fdobj != nil { + fdobj.Close() + } + return nil, err } - return nil, err - } - // Construct the new dentry. - child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr) - if err != nil { - nonOpenFile.close(ctx) - openFile.close(ctx) if fdobj != nil { - fdobj.Close() + openHostFD = int32(fdobj.Release()) } - return nil, err + openP9File = openFile } // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { - openFD := int32(-1) - if fdobj != nil { - openFD = int32(fdobj.Release()) - } child.handleMu.Lock() if vfs.MayReadFileWithOpenFlags(opts.Flags) { - child.readFile = openFile - if fdobj != nil { - child.readFD = openFD - child.mmapFD = openFD + child.readFile = openP9File + child.readFDLisa = d.fs.clientLisa.NewFD(openLisaFD) + if openHostFD != -1 { + child.readFD = openHostFD + child.mmapFD = openHostFD } } if vfs.MayWriteFileWithOpenFlags(opts.Flags) { - child.writeFile = openFile - child.writeFD = openFD + child.writeFile = openP9File + child.writeFDLisa = d.fs.clientLisa.NewFD(openLisaFD) + child.writeFD = openHostFD } child.handleMu.Unlock() } @@ -1147,11 +1393,9 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving childVFSFD = &fd.vfsfd } else { h := handle{ - file: openFile, - fd: -1, - } - if fdobj != nil { - h.fd = int32(fdobj.Release()) + file: openP9File, + fdLisa: d.fs.clientLisa.NewFD(openLisaFD), + fd: openHostFD, } fd, err := newSpecialFileFD(h, mnt, child, opts.Flags) if err != nil { @@ -1268,7 +1512,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa defer newParent.dirMu.Unlock() } if newParent.isDeleted() { - return syserror.ENOENT + return linuxerr.ENOENT } replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds) if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { @@ -1282,7 +1526,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { - return syserror.EISDIR + return linuxerr.EISDIR } if genericIsAncestorDentry(replaced, renamed) { return linuxerr.ENOTEMPTY @@ -1305,7 +1549,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Update the remote filesystem. if !renamed.isSynthetic() { - if err := renamed.file.rename(ctx, newParent.file, newName); err != nil { + if fs.opts.lisaEnabled { + err = renamed.controlFDLisa.RenameTo(ctx, newParent.controlFDLisa.ID(), newName) + } else { + err = renamed.file.rename(ctx, newParent.file, newName) + } + if err != nil { vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) return err } @@ -1316,7 +1565,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if replaced.isDir() { flags = linux.AT_REMOVEDIR } - if err := newParent.file.unlinkAt(ctx, newName, flags); err != nil { + if fs.opts.lisaEnabled { + err = newParent.controlFDLisa.UnlinkAt(ctx, newName, flags) + } else { + err = newParent.file.unlinkAt(ctx, newName, flags) + } + if err != nil { vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) return err } @@ -1432,6 +1686,28 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu for d.isSynthetic() { d = d.parent } + if fs.opts.lisaEnabled { + var statFS lisafs.StatFS + if err := d.controlFDLisa.StatFSTo(ctx, &statFS); err != nil { + return linux.Statfs{}, err + } + if statFS.NameLength > maxFilenameLen { + statFS.NameLength = maxFilenameLen + } + return linux.Statfs{ + // This is primarily for distinguishing a gofer file system in + // tests. Testing is important, so instead of defining + // something completely random, use a standard value. + Type: linux.V9FS_MAGIC, + BlockSize: statFS.BlockSize, + Blocks: statFS.Blocks, + BlocksFree: statFS.BlocksFree, + BlocksAvailable: statFS.BlocksAvailable, + Files: statFS.Files, + FilesFree: statFS.FilesFree, + NameLength: statFS.NameLength, + }, nil + } fsstat, err := d.file.statFS(ctx) if err != nil { return linux.Statfs{}, err @@ -1457,11 +1733,21 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) { creds := rp.Credentials() + if fs.opts.lisaEnabled { + return parent.controlFDLisa.SymlinkAt(ctx, name, target, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID)) + } _, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - return err - }, nil) + return nil, err + }, nil, func(child *dentry) { + if fs.opts.interop != InteropModeShared { + // lisafs caches the symlink target on creation. In practice, this + // helps avoid a lot of ReadLink RPCs. + child.haveTarget = true + child.target = target + } + }) } // UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. @@ -1506,7 +1792,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si if err != nil { return nil, err } - return d.listXattr(ctx, rp.Credentials(), size) + return d.listXattr(ctx, size) } // GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. @@ -1613,6 +1899,9 @@ func (fs *filesystem) MountOptions() string { if fs.opts.overlayfsStaleRead { optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil}) } + if fs.opts.lisaEnabled { + optsKV = append(optsKV, mopt{moptLisafs, nil}) + } opts := make([]string, 0, len(optsKV)) for _, opt := range optsKV { diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 25d2e39d6..b98825e26 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -48,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" refs_vfs1 "gvisor.dev/gvisor/pkg/refs" @@ -62,7 +63,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" ) @@ -84,6 +84,7 @@ const ( moptForcePageCache = "force_page_cache" moptLimitHostFDTranslation = "limit_host_fd_translation" moptOverlayfsStaleRead = "overlayfs_stale_read" + moptLisafs = "lisafs" ) // Valid values for the "cache" mount option. @@ -119,6 +120,10 @@ type filesystem struct { // client is the client used by this filesystem. client is immutable. client *p9.Client `state:"nosave"` + // clientLisa is the client used for communicating with the server when + // lisafs is enabled. lisafsCient is immutable. + clientLisa *lisafs.Client `state:"nosave"` + // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -162,6 +167,12 @@ type filesystem struct { inoMu sync.Mutex `state:"nosave"` inoByQIDPath map[uint64]uint64 `state:"nosave"` + // inoByKey is the same as inoByQIDPath but only used by lisafs. It helps + // identify inodes based on the device ID and host inode number provided + // by the gofer process. It is not preserved across checkpoint/restore for + // the same reason as above. inoByKey is protected by inoMu. + inoByKey map[inoKey]uint64 `state:"nosave"` + // lastIno is the last inode number assigned to a file. lastIno is accessed // using atomic memory operations. lastIno uint64 @@ -215,6 +226,10 @@ type filesystemOptions struct { // way that application FDs representing "special files" such as sockets // do. Note that this disables client caching and mmap for regular files. regularFilesUseSpecialFileFD bool + + // lisaEnabled indicates whether the client will use lisafs protocol to + // communicate with the server instead of 9P. + lisaEnabled bool } // InteropMode controls the client's interaction with other remote filesystem @@ -428,6 +443,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt delete(mopts, moptOverlayfsStaleRead) fsopts.overlayfsStaleRead = true } + if lisafs, ok := mopts[moptLisafs]; ok { + delete(mopts, moptLisafs) + fsopts.lisaEnabled, err = strconv.ParseBool(lisafs) + if err != nil { + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid lisafs option: %s", lisafs) + return nil, nil, linuxerr.EINVAL + } + } // fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying // "cache=none". @@ -459,44 +482,83 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt syncableDentries: make(map[*dentry]struct{}), specialFileFDs: make(map[*specialFileFD]struct{}), inoByQIDPath: make(map[uint64]uint64), + inoByKey: make(map[inoKey]uint64), } fs.vfsfs.Init(vfsObj, &fstype, fs) + if err := fs.initClientAndRoot(ctx); err != nil { + fs.vfsfs.DecRef(ctx) + return nil, nil, err + } + + return &fs.vfsfs, &fs.root.vfsd, nil +} + +func (fs *filesystem) initClientAndRoot(ctx context.Context) error { + var err error + if fs.opts.lisaEnabled { + var rootInode *lisafs.Inode + rootInode, err = fs.initClientLisa(ctx) + if err != nil { + return err + } + fs.root, err = fs.newDentryLisa(ctx, rootInode) + if err != nil { + fs.clientLisa.CloseFDBatched(ctx, rootInode.ControlFD) + } + } else { + fs.root, err = fs.initClient(ctx) + } + + // Set the root's reference count to 2. One reference is returned to the + // caller, and the other is held by fs to prevent the root from being "cached" + // and subsequently evicted. + if err == nil { + fs.root.refs = 2 + } + return err +} + +func (fs *filesystem) initClientLisa(ctx context.Context) (*lisafs.Inode, error) { + sock, err := unet.NewSocket(fs.opts.fd) + if err != nil { + return nil, err + } + + var rootInode *lisafs.Inode + ctx.UninterruptibleSleepStart(false) + fs.clientLisa, rootInode, err = lisafs.NewClient(sock, fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + return rootInode, err +} + +func (fs *filesystem) initClient(ctx context.Context) (*dentry, error) { // Connect to the server. if err := fs.dial(ctx); err != nil { - return nil, nil, err + return nil, err } // Perform attach to obtain the filesystem root. ctx.UninterruptibleSleepStart(false) - attached, err := fs.client.Attach(fsopts.aname) + attached, err := fs.client.Attach(fs.opts.aname) ctx.UninterruptibleSleepFinish(false) if err != nil { - fs.vfsfs.DecRef(ctx) - return nil, nil, err + return nil, err } attachFile := p9file{attached} qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) if err != nil { attachFile.close(ctx) - fs.vfsfs.DecRef(ctx) - return nil, nil, err + return nil, err } // Construct the root dentry. root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) if err != nil { attachFile.close(ctx) - fs.vfsfs.DecRef(ctx) - return nil, nil, err + return nil, err } - // Set the root's reference count to 2. One reference is returned to the - // caller, and the other is held by fs to prevent the root from being "cached" - // and subsequently evicted. - root.refs = 2 - fs.root = root - - return &fs.vfsfs, &root.vfsd, nil + return root, nil } func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { @@ -614,7 +676,11 @@ func (fs *filesystem) Release(ctx context.Context) { if !fs.iopts.LeakConnection { // Close the connection to the server. This implicitly clunks all fids. - fs.client.Close() + if fs.opts.lisaEnabled { + fs.clientLisa.Close() + } else { + fs.client.Close() + } } fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) @@ -645,6 +711,23 @@ func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { } } +// inoKey is the key used to identify the inode backed by this dentry. +// +// +stateify savable +type inoKey struct { + ino uint64 + devMinor uint32 + devMajor uint32 +} + +func inoKeyFromStat(stat *linux.Statx) inoKey { + return inoKey{ + ino: stat.Ino, + devMinor: stat.DevMinor, + devMajor: stat.DevMajor, + } +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -675,6 +758,9 @@ type dentry struct { // qidPath is the p9.QID.Path for this file. qidPath is immutable. qidPath uint64 + // inoKey is used to identify this dentry's inode. + inoKey inoKey + // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file @@ -682,6 +768,14 @@ type dentry struct { // only files that can be synthetic are sockets, pipes, and directories. file p9file `state:"nosave"` + // controlFDLisa is used by lisafs to perform path based operations on this + // dentry. + // + // if !controlFDLisa.Ok(), this dentry represents a synthetic file, i.e. a + // file that does not exist on the remote filesystem. As of this writing, the + // only files that can be synthetic are sockets, pipes, and directories. + controlFDLisa lisafs.ClientFD `state:"nosave"` + // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. deleted uint32 @@ -792,12 +886,14 @@ type dentry struct { // always either -1 or equal to readFD; if !writeFile.isNil() (the file has // been opened for writing), it is additionally either -1 or equal to // writeFD. - handleMu sync.RWMutex `state:"nosave"` - readFile p9file `state:"nosave"` - writeFile p9file `state:"nosave"` - readFD int32 `state:"nosave"` - writeFD int32 `state:"nosave"` - mmapFD int32 `state:"nosave"` + handleMu sync.RWMutex `state:"nosave"` + readFile p9file `state:"nosave"` + writeFile p9file `state:"nosave"` + readFDLisa lisafs.ClientFD `state:"nosave"` + writeFDLisa lisafs.ClientFD `state:"nosave"` + readFD int32 `state:"nosave"` + writeFD int32 `state:"nosave"` + mmapFD int32 `state:"nosave"` dataMu sync.RWMutex `state:"nosave"` @@ -865,11 +961,11 @@ func dentryAttrMask() p9.AttrMask { func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, mask p9.AttrMask, attr *p9.Attr) (*dentry, error) { if !mask.Mode { ctx.Warningf("can't create gofer.dentry without file type") - return nil, syserror.EIO + return nil, linuxerr.EIO } if attr.Mode.FileType() == p9.ModeRegular && !mask.Size { ctx.Warningf("can't create regular file gofer.dentry without file size") - return nil, syserror.EIO + return nil, linuxerr.EIO } d := &dentry{ @@ -921,6 +1017,79 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma return d, nil } +func (fs *filesystem) newDentryLisa(ctx context.Context, ino *lisafs.Inode) (*dentry, error) { + if ino.Stat.Mask&linux.STATX_TYPE == 0 { + ctx.Warningf("can't create gofer.dentry without file type") + return nil, linuxerr.EIO + } + if ino.Stat.Mode&linux.FileTypeMask == linux.ModeRegular && ino.Stat.Mask&linux.STATX_SIZE == 0 { + ctx.Warningf("can't create regular file gofer.dentry without file size") + return nil, linuxerr.EIO + } + + inoKey := inoKeyFromStat(&ino.Stat) + d := &dentry{ + fs: fs, + inoKey: inoKey, + ino: fs.inoFromKey(inoKey), + mode: uint32(ino.Stat.Mode), + uid: uint32(fs.opts.dfltuid), + gid: uint32(fs.opts.dfltgid), + blockSize: hostarch.PageSize, + readFD: -1, + writeFD: -1, + mmapFD: -1, + controlFDLisa: fs.clientLisa.NewFD(ino.ControlFD), + } + + d.pf.dentry = d + if ino.Stat.Mask&linux.STATX_UID != 0 { + d.uid = dentryUIDFromLisaUID(lisafs.UID(ino.Stat.UID)) + } + if ino.Stat.Mask&linux.STATX_GID != 0 { + d.gid = dentryGIDFromLisaGID(lisafs.GID(ino.Stat.GID)) + } + if ino.Stat.Mask&linux.STATX_SIZE != 0 { + d.size = ino.Stat.Size + } + if ino.Stat.Blksize != 0 { + d.blockSize = ino.Stat.Blksize + } + if ino.Stat.Mask&linux.STATX_ATIME != 0 { + d.atime = dentryTimestampFromLisa(ino.Stat.Atime) + } + if ino.Stat.Mask&linux.STATX_MTIME != 0 { + d.mtime = dentryTimestampFromLisa(ino.Stat.Mtime) + } + if ino.Stat.Mask&linux.STATX_CTIME != 0 { + d.ctime = dentryTimestampFromLisa(ino.Stat.Ctime) + } + if ino.Stat.Mask&linux.STATX_BTIME != 0 { + d.btime = dentryTimestampFromLisa(ino.Stat.Btime) + } + if ino.Stat.Mask&linux.STATX_NLINK != 0 { + d.nlink = ino.Stat.Nlink + } + d.vfsd.Init(d) + refsvfs2.Register(d) + fs.syncMu.Lock() + fs.syncableDentries[d] = struct{}{} + fs.syncMu.Unlock() + return d, nil +} + +func (fs *filesystem) inoFromKey(key inoKey) uint64 { + fs.inoMu.Lock() + defer fs.inoMu.Unlock() + + if ino, ok := fs.inoByKey[key]; ok { + return ino + } + ino := fs.nextIno() + fs.inoByKey[key] = ino + return ino +} + func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 { fs.inoMu.Lock() defer fs.inoMu.Unlock() @@ -937,7 +1106,7 @@ func (fs *filesystem) nextIno() uint64 { } func (d *dentry) isSynthetic() bool { - return d.file.isNil() + return !d.isControlFileOk() } func (d *dentry) cachedMetadataAuthoritative() bool { @@ -987,6 +1156,50 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } } +// updateFromLisaStatLocked is called to update d's metadata after an update +// from the remote filesystem. +// Precondition: d.metadataMu must be locked. +// +checklocks:d.metadataMu +func (d *dentry) updateFromLisaStatLocked(stat *linux.Statx) { + if stat.Mask&linux.STATX_TYPE != 0 { + if got, want := stat.Mode&linux.FileTypeMask, d.fileType(); uint32(got) != want { + panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got)) + } + } + if stat.Mask&linux.STATX_MODE != 0 { + atomic.StoreUint32(&d.mode, uint32(stat.Mode)) + } + if stat.Mask&linux.STATX_UID != 0 { + atomic.StoreUint32(&d.uid, dentryUIDFromLisaUID(lisafs.UID(stat.UID))) + } + if stat.Mask&linux.STATX_GID != 0 { + atomic.StoreUint32(&d.uid, dentryGIDFromLisaGID(lisafs.GID(stat.GID))) + } + if stat.Blksize != 0 { + atomic.StoreUint32(&d.blockSize, stat.Blksize) + } + // Don't override newer client-defined timestamps with old server-defined + // ones. + if stat.Mask&linux.STATX_ATIME != 0 && atomic.LoadUint32(&d.atimeDirty) == 0 { + atomic.StoreInt64(&d.atime, dentryTimestampFromLisa(stat.Atime)) + } + if stat.Mask&linux.STATX_MTIME != 0 && atomic.LoadUint32(&d.mtimeDirty) == 0 { + atomic.StoreInt64(&d.mtime, dentryTimestampFromLisa(stat.Mtime)) + } + if stat.Mask&linux.STATX_CTIME != 0 { + atomic.StoreInt64(&d.ctime, dentryTimestampFromLisa(stat.Ctime)) + } + if stat.Mask&linux.STATX_BTIME != 0 { + atomic.StoreInt64(&d.btime, dentryTimestampFromLisa(stat.Btime)) + } + if stat.Mask&linux.STATX_NLINK != 0 { + atomic.StoreUint32(&d.nlink, stat.Nlink) + } + if stat.Mask&linux.STATX_SIZE != 0 { + d.updateSizeLocked(stat.Size) + } +} + // Preconditions: !d.isSynthetic(). // Preconditions: d.metadataMu is locked. // +checklocks:d.metadataMu @@ -996,7 +1209,10 @@ func (d *dentry) refreshSizeLocked(ctx context.Context) error { if d.writeFD < 0 { d.handleMu.RUnlock() // Ask the gofer if we don't have a host FD. - return d.updateFromGetattrLocked(ctx) + if d.fs.opts.lisaEnabled { + return d.updateFromStatLisaLocked(ctx, nil) + } + return d.updateFromGetattrLocked(ctx, p9file{}) } var stat unix.Statx_t @@ -1015,33 +1231,77 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { // updating stale attributes in d.updateFromP9AttrsLocked(). d.metadataMu.Lock() defer d.metadataMu.Unlock() - return d.updateFromGetattrLocked(ctx) + if d.fs.opts.lisaEnabled { + return d.updateFromStatLisaLocked(ctx, nil) + } + return d.updateFromGetattrLocked(ctx, p9file{}) } // Preconditions: // * !d.isSynthetic(). // * d.metadataMu is locked. // +checklocks:d.metadataMu -func (d *dentry) updateFromGetattrLocked(ctx context.Context) error { - // Use d.readFile or d.writeFile, which represent 9P FIDs that have been - // opened, in preference to d.file, which represents a 9P fid that has not. - // This may be significantly more efficient in some implementations. Prefer - // d.writeFile over d.readFile since some filesystem implementations may - // update a writable handle's metadata after writes to that handle, without - // making metadata updates immediately visible to read-only handles - // representing the same file. - d.handleMu.RLock() - handleMuRLocked := true - var file p9file - switch { - case !d.writeFile.isNil(): - file = d.writeFile - case !d.readFile.isNil(): - file = d.readFile - default: - file = d.file - d.handleMu.RUnlock() - handleMuRLocked = false +func (d *dentry) updateFromStatLisaLocked(ctx context.Context, fdLisa *lisafs.ClientFD) error { + handleMuRLocked := false + if fdLisa == nil { + // Use open FDs in preferenece to the control FD. This may be significantly + // more efficient in some implementations. Prefer a writable FD over a + // readable one since some filesystem implementations may update a writable + // FD's metadata after writes, without making metadata updates immediately + // visible to read-only FDs representing the same file. + d.handleMu.RLock() + switch { + case d.writeFDLisa.Ok(): + fdLisa = &d.writeFDLisa + handleMuRLocked = true + case d.readFDLisa.Ok(): + fdLisa = &d.readFDLisa + handleMuRLocked = true + default: + fdLisa = &d.controlFDLisa + d.handleMu.RUnlock() + } + } + + var stat linux.Statx + err := fdLisa.StatTo(ctx, &stat) + if handleMuRLocked { + // handleMu must be released before updateFromLisaStatLocked(). + d.handleMu.RUnlock() // +checklocksforce: complex case. + } + if err != nil { + return err + } + d.updateFromLisaStatLocked(&stat) + return nil +} + +// Preconditions: +// * !d.isSynthetic(). +// * d.metadataMu is locked. +// +checklocks:d.metadataMu +func (d *dentry) updateFromGetattrLocked(ctx context.Context, file p9file) error { + handleMuRLocked := false + if file.isNil() { + // Use d.readFile or d.writeFile, which represent 9P FIDs that have + // been opened, in preference to d.file, which represents a 9P fid that + // has not. This may be significantly more efficient in some + // implementations. Prefer d.writeFile over d.readFile since some + // filesystem implementations may update a writable handle's metadata + // after writes to that handle, without making metadata updates + // immediately visible to read-only handles representing the same file. + d.handleMu.RLock() + switch { + case !d.writeFile.isNil(): + file = d.writeFile + handleMuRLocked = true + case !d.readFile.isNil(): + file = d.readFile + handleMuRLocked = true + default: + file = d.file + d.handleMu.RUnlock() + } } _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask()) @@ -1112,7 +1372,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs case linux.S_IFREG: // ok case linux.S_IFDIR: - return syserror.EISDIR + return linuxerr.EISDIR default: return linuxerr.EINVAL } @@ -1159,6 +1419,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs } } + // failureMask indicates which attributes could not be set on the remote + // filesystem. p9 returns an error if any of the attributes could not be set + // but that leads to inconsistency as the server could have set a few + // attributes successfully but a later failure will cause the successful ones + // to not be updated in the dentry cache. + var failureMask uint32 + var failureErr error if !d.isSynthetic() { if stat.Mask != 0 { if stat.Mask&linux.STATX_SIZE != 0 { @@ -1168,35 +1435,50 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // the remote file has been truncated). d.dataMu.Lock() } - if err := d.file.setAttr(ctx, p9.SetAttrMask{ - Permissions: stat.Mask&linux.STATX_MODE != 0, - UID: stat.Mask&linux.STATX_UID != 0, - GID: stat.Mask&linux.STATX_GID != 0, - Size: stat.Mask&linux.STATX_SIZE != 0, - ATime: stat.Mask&linux.STATX_ATIME != 0, - MTime: stat.Mask&linux.STATX_MTIME != 0, - ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW, - MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW, - }, p9.SetAttr{ - Permissions: p9.FileMode(stat.Mode), - UID: p9.UID(stat.UID), - GID: p9.GID(stat.GID), - Size: stat.Size, - ATimeSeconds: uint64(stat.Atime.Sec), - ATimeNanoSeconds: uint64(stat.Atime.Nsec), - MTimeSeconds: uint64(stat.Mtime.Sec), - MTimeNanoSeconds: uint64(stat.Mtime.Nsec), - }); err != nil { - if stat.Mask&linux.STATX_SIZE != 0 { - d.dataMu.Unlock() // +checklocksforce: locked conditionally above + if d.fs.opts.lisaEnabled { + var err error + failureMask, failureErr, err = d.controlFDLisa.SetStat(ctx, stat) + if err != nil { + if stat.Mask&linux.STATX_SIZE != 0 { + d.dataMu.Unlock() // +checklocksforce: locked conditionally above + } + return err + } + } else { + if err := d.file.setAttr(ctx, p9.SetAttrMask{ + Permissions: stat.Mask&linux.STATX_MODE != 0, + UID: stat.Mask&linux.STATX_UID != 0, + GID: stat.Mask&linux.STATX_GID != 0, + Size: stat.Mask&linux.STATX_SIZE != 0, + ATime: stat.Mask&linux.STATX_ATIME != 0, + MTime: stat.Mask&linux.STATX_MTIME != 0, + ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW, + MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW, + }, p9.SetAttr{ + Permissions: p9.FileMode(stat.Mode), + UID: p9.UID(stat.UID), + GID: p9.GID(stat.GID), + Size: stat.Size, + ATimeSeconds: uint64(stat.Atime.Sec), + ATimeNanoSeconds: uint64(stat.Atime.Nsec), + MTimeSeconds: uint64(stat.Mtime.Sec), + MTimeNanoSeconds: uint64(stat.Mtime.Nsec), + }); err != nil { + if stat.Mask&linux.STATX_SIZE != 0 { + d.dataMu.Unlock() // +checklocksforce: locked conditionally above + } + return err } - return err } if stat.Mask&linux.STATX_SIZE != 0 { - // d.size should be kept up to date, and privatized - // copy-on-write mappings of truncated pages need to be - // invalidated, even if InteropModeShared is in effect. - d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above + if failureMask&linux.STATX_SIZE == 0 { + // d.size should be kept up to date, and privatized + // copy-on-write mappings of truncated pages need to be + // invalidated, even if InteropModeShared is in effect. + d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above + } else { + d.dataMu.Unlock() // +checklocksforce: locked conditionally above + } } } if d.fs.opts.interop == InteropModeShared { @@ -1207,13 +1489,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } } - if stat.Mask&linux.STATX_MODE != 0 { + if stat.Mask&linux.STATX_MODE != 0 && failureMask&linux.STATX_MODE == 0 { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } - if stat.Mask&linux.STATX_UID != 0 { + if stat.Mask&linux.STATX_UID != 0 && failureMask&linux.STATX_UID == 0 { atomic.StoreUint32(&d.uid, stat.UID) } - if stat.Mask&linux.STATX_GID != 0 { + if stat.Mask&linux.STATX_GID != 0 && failureMask&linux.STATX_GID == 0 { atomic.StoreUint32(&d.gid, stat.GID) } // Note that stat.Atime.Nsec and stat.Mtime.Nsec can't be UTIME_NOW because @@ -1221,15 +1503,19 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // stat.Mtime to client-local timestamps above, and if // !d.cachedMetadataAuthoritative() then we returned after calling // d.file.setAttr(). For the same reason, now must have been initialized. - if stat.Mask&linux.STATX_ATIME != 0 { + if stat.Mask&linux.STATX_ATIME != 0 && failureMask&linux.STATX_ATIME == 0 { atomic.StoreInt64(&d.atime, stat.Atime.ToNsec()) atomic.StoreUint32(&d.atimeDirty, 0) } - if stat.Mask&linux.STATX_MTIME != 0 { + if stat.Mask&linux.STATX_MTIME != 0 && failureMask&linux.STATX_MTIME == 0 { atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec()) atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) + if failureMask != 0 { + // Setting some attribute failed on the remote filesystem. + return failureErr + } return nil } @@ -1309,7 +1595,10 @@ func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats // (b/148380782). Allow all other extended attributes to be passed through // to the remote filesystem. This is inconsistent with Linux's 9p client, // but consistent with other filesystems (e.g. FUSE). - if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) { + // + // NOTE(b/202533394): Also disallow "trusted" namespace for now. This is + // consistent with the VFS1 gofer client. + if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) || strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { return linuxerr.EOPNOTSUPP } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) @@ -1345,6 +1634,20 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 { return uint32(gid) } +func dentryUIDFromLisaUID(uid lisafs.UID) uint32 { + if !uid.Ok() { + return uint32(auth.OverflowUID) + } + return uint32(uid) +} + +func dentryGIDFromLisaGID(gid lisafs.GID) uint32 { + if !gid.Ok() { + return uint32(auth.OverflowGID) + } + return uint32(gid) +} + // IncRef implements vfs.DentryImpl.IncRef. func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against @@ -1653,15 +1956,24 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.dirty.RemoveAll() } d.dataMu.Unlock() - // Clunk open fids and close open host FDs. - if !d.readFile.isNil() { - _ = d.readFile.close(ctx) - } - if !d.writeFile.isNil() && d.readFile != d.writeFile { - _ = d.writeFile.close(ctx) + if d.fs.opts.lisaEnabled { + if d.readFDLisa.Ok() && d.readFDLisa.ID() != d.writeFDLisa.ID() { + d.readFDLisa.CloseBatched(ctx) + } + if d.writeFDLisa.Ok() { + d.writeFDLisa.CloseBatched(ctx) + } + } else { + // Clunk open fids and close open host FDs. + if !d.readFile.isNil() { + _ = d.readFile.close(ctx) + } + if !d.writeFile.isNil() && d.readFile != d.writeFile { + _ = d.writeFile.close(ctx) + } + d.readFile = p9file{} + d.writeFile = p9file{} } - d.readFile = p9file{} - d.writeFile = p9file{} if d.readFD >= 0 { _ = unix.Close(int(d.readFD)) } @@ -1673,7 +1985,7 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.mmapFD = -1 d.handleMu.Unlock() - if !d.file.isNil() { + if d.isControlFileOk() { // Note that it's possible that d.atimeDirty or d.mtimeDirty are true, // i.e. client and server timestamps may differ (because e.g. a client // write was serviced by the page cache, and only written back to the @@ -1682,10 +1994,16 @@ func (d *dentry) destroyLocked(ctx context.Context) { // instantiated for the same file would remain coherent. Unfortunately, // this turns out to be too expensive in many cases, so for now we // don't do this. - if err := d.file.close(ctx); err != nil { - log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err) + + // Close the control FD. + if d.fs.opts.lisaEnabled { + d.controlFDLisa.CloseBatched(ctx) + } else { + if err := d.file.close(ctx); err != nil { + log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err) + } + d.file = p9file{} } - d.file = p9file{} // Remove d from the set of syncable dentries. d.fs.syncMu.Lock() @@ -1711,10 +2029,29 @@ func (d *dentry) setDeleted() { atomic.StoreUint32(&d.deleted, 1) } -func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { - if d.file.isNil() { +func (d *dentry) isControlFileOk() bool { + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.Ok() + } + return !d.file.isNil() +} + +func (d *dentry) isReadFileOk() bool { + if d.fs.opts.lisaEnabled { + return d.readFDLisa.Ok() + } + return !d.readFile.isNil() +} + +func (d *dentry) listXattr(ctx context.Context, size uint64) ([]string, error) { + if !d.isControlFileOk() { return nil, nil } + + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.ListXattr(ctx, size) + } + xattrMap, err := d.file.listXattr(ctx, size) if err != nil { return nil, err @@ -1727,32 +2064,41 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui } func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { - if d.file.isNil() { + if !d.isControlFileOk() { return "", linuxerr.ENODATA } if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { return "", err } + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.GetXattr(ctx, opts.Name, opts.Size) + } return d.file.getXattr(ctx, opts.Name, opts.Size) } func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error { - if d.file.isNil() { + if !d.isControlFileOk() { return linuxerr.EPERM } if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { return err } + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.SetXattr(ctx, opts.Name, opts.Value, opts.Flags) + } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error { - if d.file.isNil() { + if !d.isControlFileOk() { return linuxerr.EPERM } if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { return err } + if d.fs.opts.lisaEnabled { + return d.controlFDLisa.RemoveXattr(ctx, name) + } return d.file.removeXattr(ctx, name) } @@ -1764,19 +2110,30 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // O_TRUNC). if !trunc { d.handleMu.RLock() - if (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) { + var canReuseCurHandle bool + if d.fs.opts.lisaEnabled { + canReuseCurHandle = (!read || d.readFDLisa.Ok()) && (!write || d.writeFDLisa.Ok()) + } else { + canReuseCurHandle = (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) + } + d.handleMu.RUnlock() + if canReuseCurHandle { // Current handles are sufficient. - d.handleMu.RUnlock() return nil } - d.handleMu.RUnlock() } var fdsToCloseArr [2]int32 fdsToClose := fdsToCloseArr[:0] invalidateTranslations := false d.handleMu.Lock() - if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc { + var needNewHandle bool + if d.fs.opts.lisaEnabled { + needNewHandle = (read && !d.readFDLisa.Ok()) || (write && !d.writeFDLisa.Ok()) || trunc + } else { + needNewHandle = (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc + } + if needNewHandle { // Get a new handle. If this file has been opened for both reading and // writing, try to get a single handle that is usable for both: // @@ -1785,9 +2142,21 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // // - NOTE(b/141991141): Some filesystems may not ensure coherence // between multiple handles for the same file. - openReadable := !d.readFile.isNil() || read - openWritable := !d.writeFile.isNil() || write - h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc) + var ( + openReadable bool + openWritable bool + h handle + err error + ) + if d.fs.opts.lisaEnabled { + openReadable = d.readFDLisa.Ok() || read + openWritable = d.writeFDLisa.Ok() || write + h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc) + } else { + openReadable = !d.readFile.isNil() || read + openWritable = !d.writeFile.isNil() || write + h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc) + } if linuxerr.Equals(linuxerr.EACCES, err) && (openReadable != read || openWritable != write) { // It may not be possible to use a single handle for both // reading and writing, since permissions on the file may have @@ -1797,7 +2166,11 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool ctx.Debugf("gofer.dentry.ensureSharedHandle: bifurcating read/write handles for dentry %p", d) openReadable = read openWritable = write - h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc) + if d.fs.opts.lisaEnabled { + h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc) + } else { + h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc) + } } if err != nil { d.handleMu.Unlock() @@ -1859,9 +2232,16 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // previously opened for reading (without an FD), then existing // translations of the file may use the internal page cache; // invalidate those mappings. - if d.writeFile.isNil() { - invalidateTranslations = !d.readFile.isNil() - atomic.StoreInt32(&d.mmapFD, h.fd) + if d.fs.opts.lisaEnabled { + if !d.writeFDLisa.Ok() { + invalidateTranslations = d.readFDLisa.Ok() + atomic.StoreInt32(&d.mmapFD, h.fd) + } + } else { + if d.writeFile.isNil() { + invalidateTranslations = !d.readFile.isNil() + atomic.StoreInt32(&d.mmapFD, h.fd) + } } } else if openWritable && d.writeFD < 0 { atomic.StoreInt32(&d.writeFD, h.fd) @@ -1888,24 +2268,45 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool atomic.StoreInt32(&d.mmapFD, -1) } - // Switch to new fids. - var oldReadFile p9file - if openReadable { - oldReadFile = d.readFile - d.readFile = h.file - } - var oldWriteFile p9file - if openWritable { - oldWriteFile = d.writeFile - d.writeFile = h.file - } - // NOTE(b/141991141): Clunk old fids before making new fids visible (by - // unlocking d.handleMu). - if !oldReadFile.isNil() { - oldReadFile.close(ctx) - } - if !oldWriteFile.isNil() && oldReadFile != oldWriteFile { - oldWriteFile.close(ctx) + // Switch to new fids/FDs. + if d.fs.opts.lisaEnabled { + oldReadFD := lisafs.InvalidFDID + if openReadable { + oldReadFD = d.readFDLisa.ID() + d.readFDLisa = h.fdLisa + } + oldWriteFD := lisafs.InvalidFDID + if openWritable { + oldWriteFD = d.writeFDLisa.ID() + d.writeFDLisa = h.fdLisa + } + // NOTE(b/141991141): Close old FDs before making new fids visible (by + // unlocking d.handleMu). + if oldReadFD.Ok() { + d.fs.clientLisa.CloseFDBatched(ctx, oldReadFD) + } + if oldWriteFD.Ok() && oldReadFD != oldWriteFD { + d.fs.clientLisa.CloseFDBatched(ctx, oldWriteFD) + } + } else { + var oldReadFile p9file + if openReadable { + oldReadFile = d.readFile + d.readFile = h.file + } + var oldWriteFile p9file + if openWritable { + oldWriteFile = d.writeFile + d.writeFile = h.file + } + // NOTE(b/141991141): Clunk old fids before making new fids visible (by + // unlocking d.handleMu). + if !oldReadFile.isNil() { + oldReadFile.close(ctx) + } + if !oldWriteFile.isNil() && oldReadFile != oldWriteFile { + oldWriteFile.close(ctx) + } } } d.handleMu.Unlock() @@ -1929,27 +2330,29 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // Preconditions: d.handleMu must be locked. func (d *dentry) readHandleLocked() handle { return handle{ - file: d.readFile, - fd: d.readFD, + fdLisa: d.readFDLisa, + file: d.readFile, + fd: d.readFD, } } // Preconditions: d.handleMu must be locked. func (d *dentry) writeHandleLocked() handle { return handle{ - file: d.writeFile, - fd: d.writeFD, + fdLisa: d.writeFDLisa, + file: d.writeFile, + fd: d.writeFD, } } func (d *dentry) syncRemoteFile(ctx context.Context) error { d.handleMu.RLock() defer d.handleMu.RUnlock() - return d.syncRemoteFileLocked(ctx) + return d.syncRemoteFileLocked(ctx, nil /* accFsyncFDIDsLisa */) } // Preconditions: d.handleMu must be locked. -func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { +func (d *dentry) syncRemoteFileLocked(ctx context.Context, accFsyncFDIDsLisa *[]lisafs.FDID) error { // If we have a host FD, fsyncing it is likely to be faster than an fsync // RPC. Prefer syncing write handles over read handles, since some remote // filesystem implementations may not sync changes made through write @@ -1960,7 +2363,13 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { ctx.UninterruptibleSleepFinish(false) return err } - if !d.writeFile.isNil() { + if d.fs.opts.lisaEnabled && d.writeFDLisa.Ok() { + if accFsyncFDIDsLisa != nil { + *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.writeFDLisa.ID()) + return nil + } + return d.writeFDLisa.Sync(ctx) + } else if !d.fs.opts.lisaEnabled && !d.writeFile.isNil() { return d.writeFile.fsync(ctx) } if d.readFD >= 0 { @@ -1969,13 +2378,19 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { ctx.UninterruptibleSleepFinish(false) return err } - if !d.readFile.isNil() { + if d.fs.opts.lisaEnabled && d.readFDLisa.Ok() { + if accFsyncFDIDsLisa != nil { + *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.readFDLisa.ID()) + return nil + } + return d.readFDLisa.Sync(ctx) + } else if !d.fs.opts.lisaEnabled && !d.readFile.isNil() { return d.readFile.fsync(ctx) } return nil } -func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error { +func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error { d.handleMu.RLock() defer d.handleMu.RUnlock() h := d.writeHandleLocked() @@ -1988,7 +2403,7 @@ func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) err return err } } - if err := d.syncRemoteFileLocked(ctx); err != nil { + if err := d.syncRemoteFileLocked(ctx, accFsyncFDIDsLisa); err != nil { if !forFilesystemSync { return err } @@ -2045,10 +2460,33 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu d := fd.dentry() const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME) if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC { - // TODO(jamieliu): Use specialFileFD.handle.file for the getattr if - // available? - if err := d.updateFromGetattr(ctx); err != nil { - return linux.Statx{}, err + if d.fs.opts.lisaEnabled { + // Use specialFileFD.handle.fileLisa for the Stat if available, for the + // same reason that we try to use open FD in updateFromStatLisaLocked(). + var fdLisa *lisafs.ClientFD + if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok { + fdLisa = &sffd.handle.fdLisa + } + d.metadataMu.Lock() + err := d.updateFromStatLisaLocked(ctx, fdLisa) + d.metadataMu.Unlock() + if err != nil { + return linux.Statx{}, err + } + } else { + // Use specialFileFD.handle.file for the getattr if available, for the + // same reason that we try to use open file handles in + // dentry.updateFromGetattrLocked(). + var file p9file + if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok { + file = sffd.handle.file + } + d.metadataMu.Lock() + err := d.updateFromGetattrLocked(ctx, file) + d.metadataMu.Unlock() + if err != nil { + return linux.Statx{}, err + } } } var stat linux.Statx @@ -2069,7 +2507,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) // ListXattr implements vfs.FileDescriptionImpl.ListXattr. func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { - return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size) + return fd.dentry().listXattr(ctx, size) } // GetXattr implements vfs.FileDescriptionImpl.GetXattr. diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index 806392d50..d5cc73f33 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -33,6 +33,7 @@ func TestDestroyIdempotent(t *testing.T) { }, syncableDentries: make(map[*dentry]struct{}), inoByQIDPath: make(map[uint64]uint64), + inoByKey: make(map[inoKey]uint64), } attr := &p9.Attr{ diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go index 5c57f6fea..394aecd62 100644 --- a/pkg/sentry/fsimpl/gofer/handle.go +++ b/pkg/sentry/fsimpl/gofer/handle.go @@ -17,18 +17,23 @@ package gofer import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/sync" ) // handle represents a remote "open file descriptor", consisting of an opened // fid (p9.File) and optionally a host file descriptor. // +// If lisafs is being used, fdLisa points to an open file on the server. +// // These are explicitly not savable. type handle struct { - file p9file - fd int32 // -1 if unavailable + fdLisa lisafs.ClientFD + file p9file + fd int32 // -1 if unavailable } // Preconditions: read || write. @@ -64,13 +69,47 @@ func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (hand }, nil } +// Preconditions: read || write. +func openHandleLisa(ctx context.Context, fdLisa lisafs.ClientFD, read, write, trunc bool) (handle, error) { + var flags uint32 + switch { + case read && write: + flags = unix.O_RDWR + case read: + flags = unix.O_RDONLY + case write: + flags = unix.O_WRONLY + default: + panic("tried to open unreadable and unwritable handle") + } + if trunc { + flags |= unix.O_TRUNC + } + openFD, hostFD, err := fdLisa.OpenAt(ctx, flags) + if err != nil { + return handle{fd: -1}, err + } + h := handle{ + fdLisa: fdLisa.Client().NewFD(openFD), + fd: int32(hostFD), + } + return h, nil +} + func (h *handle) isOpen() bool { + if h.fdLisa.Client() != nil { + return h.fdLisa.Ok() + } return !h.file.isNil() } func (h *handle) close(ctx context.Context) { - h.file.close(ctx) - h.file = p9file{} + if h.fdLisa.Client() != nil { + h.fdLisa.CloseBatched(ctx) + } else { + h.file.close(ctx) + h.file = p9file{} + } if h.fd >= 0 { unix.Close(int(h.fd)) h.fd = -1 @@ -88,19 +127,27 @@ func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offs return n, err } if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() { - n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset) - return uint64(n), err + if h.fdLisa.Client() != nil { + return h.fdLisa.Read(ctx, dsts.Head().ToSlice(), offset) + } + return h.file.readAt(ctx, dsts.Head().ToSlice(), offset) } // Buffer the read since p9.File.ReadAt() takes []byte. buf := make([]byte, dsts.NumBytes()) - n, err := h.file.readAt(ctx, buf, offset) + var n uint64 + var err error + if h.fdLisa.Client() != nil { + n, err = h.fdLisa.Read(ctx, buf, offset) + } else { + n, err = h.file.readAt(ctx, buf, offset) + } if n == 0 { return 0, err } if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil { return cp, cperr } - return uint64(n), err + return n, err } func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { @@ -114,8 +161,10 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o return n, err } if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() { - n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset) - return uint64(n), err + if h.fdLisa.Client() != nil { + return h.fdLisa.Write(ctx, srcs.Head().ToSlice(), offset) + } + return h.file.writeAt(ctx, srcs.Head().ToSlice(), offset) } // Buffer the write since p9.File.WriteAt() takes []byte. buf := make([]byte, srcs.NumBytes()) @@ -123,10 +172,56 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o if cp == 0 { return 0, cperr } - n, err := h.file.writeAt(ctx, buf[:cp], offset) + var n uint64 + var err error + if h.fdLisa.Client() != nil { + n, err = h.fdLisa.Write(ctx, buf[:cp], offset) + } else { + n, err = h.file.writeAt(ctx, buf[:cp], offset) + } // err takes precedence over cperr. if err != nil { - return uint64(n), err + return n, err } - return uint64(n), cperr + return n, cperr +} + +type handleReadWriter struct { + ctx context.Context + h *handle + off uint64 +} + +var handleReadWriterPool = sync.Pool{ + New: func() interface{} { + return &handleReadWriter{} + }, +} + +func getHandleReadWriter(ctx context.Context, h *handle, offset int64) *handleReadWriter { + rw := handleReadWriterPool.Get().(*handleReadWriter) + rw.ctx = ctx + rw.h = h + rw.off = uint64(offset) + return rw +} + +func putHandleReadWriter(rw *handleReadWriter) { + rw.ctx = nil + rw.h = nil + handleReadWriterPool.Put(rw) +} + +// ReadToBlocks implements safemem.Reader.ReadToBlocks. +func (rw *handleReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + n, err := rw.h.readToBlocksAt(rw.ctx, dsts, rw.off) + rw.off += n + return n, err +} + +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. +func (rw *handleReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + n, err := rw.h.writeFromBlocksAt(rw.ctx, srcs, rw.off) + rw.off += n + return n, err } diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go index 398288ee3..505916a57 100644 --- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -22,7 +22,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/syserror" ) // Global pipe used by blockUntilNonblockingPipeHasWriter since we can't create @@ -109,6 +108,6 @@ func sleepBetweenNamedPipeOpenChecks(ctx context.Context) error { return nil case <-cancel: ctx.SleepFinish(false) - return syserror.ErrInterrupted + return linuxerr.ErrInterrupted } } diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index b0a429d42..0d97b60fd 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -16,9 +16,9 @@ package gofer import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/syserror" ) // p9file is a wrapper around p9.File that provides methods that are @@ -59,7 +59,7 @@ func (f p9file) walkGetAttrOne(ctx context.Context, name string) (p9.QID, p9file if newfile != nil { p9file{newfile}.close(ctx) } - return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, syserror.EIO + return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, linuxerr.EIO } return qids[0], p9file{newfile}, attrMask, attr, nil } @@ -141,18 +141,18 @@ func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, u return fdobj, qid, iounit, err } -func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) { +func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (uint64, error) { ctx.UninterruptibleSleepStart(false) n, err := f.file.ReadAt(p, offset) ctx.UninterruptibleSleepFinish(false) - return n, err + return uint64(n), err } -func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) { +func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (uint64, error) { ctx.UninterruptibleSleepStart(false) n, err := f.file.WriteAt(p, offset) ctx.UninterruptibleSleepFinish(false) - return n, err + return uint64(n), err } func (f p9file) fsync(ctx context.Context) error { diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 947dbe05f..874f9873d 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -98,6 +98,12 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { } d.handleMu.RLock() defer d.handleMu.RUnlock() + if d.fs.opts.lisaEnabled { + if !d.writeFDLisa.Ok() { + return nil + } + return d.writeFDLisa.Flush(ctx) + } if d.writeFile.isNil() { return nil } @@ -110,6 +116,9 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint return d.doAllocate(ctx, offset, length, func() error { d.handleMu.RLock() defer d.handleMu.RUnlock() + if d.fs.opts.lisaEnabled { + return d.writeFDLisa.Allocate(ctx, mode, offset, length) + } return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) }) } @@ -282,8 +291,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off // changes to the host. if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode { atomic.StoreUint32(&d.mode, newMode) - if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil { - return 0, offset, err + if d.fs.opts.lisaEnabled { + stat := linux.Statx{Mask: linux.STATX_MODE, Mode: uint16(newMode)} + failureMask, failureErr, err := d.controlFDLisa.SetStat(ctx, &stat) + if err != nil { + return 0, offset, err + } + if failureMask != 0 { + return 0, offset, failureErr + } + } else { + if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil { + return 0, offset, err + } } } } @@ -677,7 +697,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */) + return fd.dentry().syncCachedFile(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go index 226790a11..5d4009832 100644 --- a/pkg/sentry/fsimpl/gofer/revalidate.go +++ b/pkg/sentry/fsimpl/gofer/revalidate.go @@ -15,7 +15,9 @@ package gofer import ( + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" ) @@ -234,28 +236,54 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF } // Lock metadata on all dentries *before* getting attributes for them. state.lockAllMetadata() - stats, err := state.start.file.multiGetAttr(ctx, state.names) - if err != nil { - return err + + var ( + stats []p9.FullStat + statsLisa []linux.Statx + numStats int + ) + if fs.opts.lisaEnabled { + var err error + statsLisa, err = state.start.controlFDLisa.WalkStat(ctx, state.names) + if err != nil { + return err + } + numStats = len(statsLisa) + } else { + var err error + stats, err = state.start.file.multiGetAttr(ctx, state.names) + if err != nil { + return err + } + numStats = len(stats) } i := -1 for d := state.popFront(); d != nil; d = state.popFront() { i++ - found := i < len(stats) + found := i < numStats if i == 0 && len(state.names[0]) == 0 { if found && !d.isSynthetic() { // First dentry is where the search is starting, just update attributes // since it cannot be replaced. - d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata. + if fs.opts.lisaEnabled { + d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: acquired by lockAllMetadata. + } else { + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata. + } } d.metadataMu.Unlock() // +checklocksforce: see above. continue } - // Note that synthetic dentries will always fails the comparison check - // below. - if !found || d.qidPath != stats[i].QID.Path { + // Note that synthetic dentries will always fail this comparison check. + var shouldInvalidate bool + if fs.opts.lisaEnabled { + shouldInvalidate = !found || d.inoKey != inoKeyFromStat(&statsLisa[i]) + } else { + shouldInvalidate = !found || d.qidPath != stats[i].QID.Path + } + if shouldInvalidate { d.metadataMu.Unlock() // +checklocksforce: see above. if !found && d.isSynthetic() { // We have a synthetic file, and no remote file has arisen to replace @@ -298,7 +326,11 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF } // The file at this path hasn't changed. Just update cached metadata. - d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above. + if fs.opts.lisaEnabled { + d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: see above. + } else { + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above. + } d.metadataMu.Unlock() } diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go index e67422a2f..475322527 100644 --- a/pkg/sentry/fsimpl/gofer/save_restore.go +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/safemem" @@ -112,10 +113,19 @@ func (d *dentry) prepareSaveRecursive(ctx context.Context) error { return err } } - if !d.readFile.isNil() || !d.writeFile.isNil() { - d.fs.savedDentryRW[d] = savedDentryRW{ - read: !d.readFile.isNil(), - write: !d.writeFile.isNil(), + if d.fs.opts.lisaEnabled { + if d.readFDLisa.Ok() || d.writeFDLisa.Ok() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: d.readFDLisa.Ok(), + write: d.writeFDLisa.Ok(), + } + } + } else { + if !d.readFile.isNil() || !d.writeFile.isNil() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: !d.readFile.isNil(), + write: !d.writeFile.isNil(), + } } } d.dirMu.Lock() @@ -158,6 +168,10 @@ func (d *dentryPlatformFile) afterLoad() { // afterLoad is invoked by stateify. func (fd *specialFileFD) afterLoad() { fd.handle.fd = -1 + if fd.hostFileMapper.IsInited() { + // Ensure that we don't call fd.hostFileMapper.Init() again. + fd.hostFileMapperInitOnce.Do(func() {}) + } } // CompleteRestore implements @@ -173,25 +187,37 @@ func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRest return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID) } fs.opts.fd = fd - if err := fs.dial(ctx); err != nil { - return err - } fs.inoByQIDPath = make(map[uint64]uint64) + fs.inoByKey = make(map[inoKey]uint64) - // Restore the filesystem root. - ctx.UninterruptibleSleepStart(false) - attached, err := fs.client.Attach(fs.opts.aname) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - return err - } - attachFile := p9file{attached} - qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) - if err != nil { - return err - } - if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { - return err + if fs.opts.lisaEnabled { + rootInode, err := fs.initClientLisa(ctx) + if err != nil { + return err + } + if err := fs.root.restoreFileLisa(ctx, rootInode, &opts); err != nil { + return err + } + } else { + if err := fs.dial(ctx); err != nil { + return err + } + + // Restore the filesystem root. + ctx.UninterruptibleSleepStart(false) + attached, err := fs.client.Attach(fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return err + } + attachFile := p9file{attached} + qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) + if err != nil { + return err + } + if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { + return err + } } // Restore remaining dentries. @@ -279,6 +305,55 @@ func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrM return nil } +func (d *dentry) restoreFileLisa(ctx context.Context, inode *lisafs.Inode, opts *vfs.CompleteRestoreOptions) error { + d.controlFDLisa = d.fs.clientLisa.NewFD(inode.ControlFD) + + // Gofers do not preserve inoKey across checkpoint/restore, so: + // + // - We must assume that the remote filesystem did not change in a way that + // would invalidate dentries, since we can't revalidate dentries by + // checking inoKey. + // + // - We need to associate the new inoKey with the existing d.ino. + d.inoKey = inoKeyFromStat(&inode.Stat) + d.fs.inoMu.Lock() + d.fs.inoByKey[d.inoKey] = d.ino + d.fs.inoMu.Unlock() + + // Check metadata stability before updating metadata. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + if d.isRegularFile() { + if opts.ValidateFileSizes { + if inode.Stat.Mask&linux.STATX_SIZE != 0 { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d)) + } + if d.size != inode.Stat.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, inode.Stat.Size) + } + } + if opts.ValidateFileModificationTimestamps { + if inode.Stat.Mask&linux.STATX_MTIME != 0 { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d)) + } + if want := dentryTimestampFromLisa(inode.Stat.Mtime); d.mtime != want { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want)) + } + } + } + if !d.cachedMetadataAuthoritative() { + d.updateFromLisaStatLocked(&inode.Stat) + } + + if rw, ok := d.fs.savedDentryRW[d]; ok { + if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil { + return err + } + } + + return nil +} + // Preconditions: d is not synthetic. func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { for _, child := range d.children { @@ -301,19 +376,35 @@ func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.Comp // only be detected by checking filesystem.syncableDentries). d.parent has been // restored. func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { - qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) - if err != nil { - return err - } - if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { - return err + if d.fs.opts.lisaEnabled { + inode, err := d.parent.controlFDLisa.Walk(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFileLisa(ctx, inode, opts); err != nil { + return err + } + } else { + qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { + return err + } } return d.restoreDescendantsRecursive(ctx, opts) } func (fd *specialFileFD) completeRestore(ctx context.Context) error { d := fd.dentry() - h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + var h handle + var err error + if d.fs.opts.lisaEnabled { + h, err = openHandleLisa(ctx, d.controlFDLisa, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + } else { + h, err = openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + } if err != nil { return err } diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index fe15f8583..86ab70453 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -59,11 +59,6 @@ func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) { // BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect. func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error { - cf, ok := sockTypeToP9(ce.Type()) - if !ok { - return syserr.ErrConnectionRefused - } - // No lock ordering required as only the ConnectingEndpoint has a mutex. ce.Lock() @@ -77,7 +72,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec return syserr.ErrInvalidEndpointState } - c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue()) + c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue()) if err != nil { ce.Unlock() return err @@ -95,7 +90,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec // UnidirectionalConnect implements // transport.BoundEndpoint.UnidirectionalConnect. func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) { - c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{}) + c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{}) if err != nil { return nil, err } @@ -111,25 +106,39 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect return c, nil } -func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { - hostFile, err := e.dentry.file.connect(ctx, flags) - if err != nil { +func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { + if e.dentry.fs.opts.lisaEnabled { + hostSockFD, err := e.dentry.controlFDLisa.Connect(ctx, sockType) + if err != nil { + return nil, syserr.ErrConnectionRefused + } + + c, serr := host.NewSCMEndpoint(ctx, hostSockFD, queue, e.path) + if serr != nil { + unix.Close(hostSockFD) + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr) + return nil, serr + } + return c, nil + } + + flags, ok := sockTypeToP9(sockType) + if !ok { return nil, syserr.ErrConnectionRefused } - // Dup the fd so that the new endpoint can manage its lifetime. - hostFD, err := unix.Dup(hostFile.FD()) + hostFile, err := e.dentry.file.connect(ctx, flags) if err != nil { - log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err) - return nil, syserr.FromError(err) + return nil, syserr.ErrConnectionRefused } - // After duplicating, we no longer need hostFile. - hostFile.Close() - c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path) + c, serr := host.NewSCMEndpoint(ctx, hostFile.FD(), queue, e.path) if serr != nil { - log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr) + hostFile.Close() + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr) return nil, serr } + // Ownership has been transferred to c. + hostFile.Release() return c, nil } diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 4b59c1c3c..c568bbfd2 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -22,13 +22,16 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fsmetric" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -76,6 +79,16 @@ type specialFileFD struct { bufMu sync.Mutex `state:"nosave"` haveBuf uint32 buf []byte + + // If handle.fd >= 0, hostFileMapper caches mappings of handle.fd, and + // hostFileMapperInitOnce is used to initialize it on first use. + hostFileMapperInitOnce sync.Once `state:"nosave"` + hostFileMapper fsutil.HostFileMapper + + // If handle.fd >= 0, fileRefs counts references on memmap.File offsets. + // fileRefs is protected by fileRefsMu. + fileRefsMu sync.Mutex `state:"nosave"` + fileRefs fsutil.FrameRefSet } func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) { @@ -137,6 +150,9 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error { if !fd.vfsfd.IsWritable() { return nil } + if fs := fd.filesystem(); fs.opts.lisaEnabled { + return fd.handle.fdLisa.Flush(ctx) + } return fd.handle.file.flush(ctx) } @@ -172,6 +188,9 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint if fd.isRegularFile { d := fd.dentry() return d.doAllocate(ctx, offset, length, func() error { + if d.fs.opts.lisaEnabled { + return fd.handle.fdLisa.Allocate(ctx, mode, offset, length) + } return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) }) } @@ -230,23 +249,13 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs } } - // Going through dst.CopyOutFrom() would hold MM locks around file - // operations of unknown duration. For regularFileFD, doing so is necessary - // to support mmap due to lock ordering; MM locks precede dentry.dataMu. - // That doesn't hold here since specialFileFD doesn't client-cache data. - // Just buffer the read instead. - buf := make([]byte, dst.NumBytes()) - n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) + rw := getHandleReadWriter(ctx, &fd.handle, offset) + n, err := dst.CopyOutFrom(ctx, rw) + putHandleReadWriter(rw) if linuxerr.Equals(linuxerr.EAGAIN, err) { - err = syserror.ErrWouldBlock - } - if n == 0 { - return bufN, err + err = linuxerr.ErrWouldBlock } - if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil { - return bufN + int64(cp), cperr - } - return bufN + int64(n), err + return bufN + n, err } // Read implements vfs.FileDescriptionImpl.Read. @@ -317,20 +326,15 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } } - // Do a buffered write. See rationale in PRead. - buf := make([]byte, src.NumBytes()) - copied, copyErr := src.CopyIn(ctx, buf) - if copied == 0 && copyErr != nil { - // Only return the error if we didn't get any data. - return 0, offset, copyErr - } - n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset)) + rw := getHandleReadWriter(ctx, &fd.handle, offset) + n, err := src.CopyInTo(ctx, rw) + putHandleReadWriter(rw) if linuxerr.Equals(linuxerr.EAGAIN, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } // Update offset if the offset is valid. if offset >= 0 { - offset += int64(n) + offset += n } // Update file size for regular files. if fd.isRegularFile { @@ -341,10 +345,7 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off atomic.StoreUint64(&d.size, uint64(offset)) } } - if err != nil { - return int64(n), offset, err - } - return int64(n), offset, copyErr + return int64(n), offset, err } // Write implements vfs.FileDescriptionImpl.Write. @@ -377,10 +378,10 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - return fd.sync(ctx, false /* forFilesystemSync */) + return fd.sync(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */) } -func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error { +func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error { // Locks to ensure it didn't race with fd.Release(). fd.releaseMu.RLock() defer fd.releaseMu.RUnlock() @@ -397,6 +398,13 @@ func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error ctx.UninterruptibleSleepFinish(false) return err } + if fs := fd.filesystem(); fs.opts.lisaEnabled { + if accFsyncFDIDsLisa != nil { + *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, fd.handle.fdLisa.ID()) + return nil + } + return fd.handle.fdLisa.Sync(ctx) + } return fd.handle.file.fsync(ctx) }() if err != nil { @@ -412,3 +420,85 @@ func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error } return nil } + +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *specialFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + if fd.handle.fd < 0 || fd.filesystem().opts.forcePageCache { + return linuxerr.ENODEV + } + // After this point, fd may be used as a memmap.Mappable and memmap.File. + fd.hostFileMapperInitOnce.Do(fd.hostFileMapper.Init) + return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts) +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (fd *specialFileFD) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { + fd.hostFileMapper.IncRefOn(memmap.MappableRange{offset, offset + uint64(ar.Length())}) + return nil +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (fd *specialFileFD) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { + fd.hostFileMapper.DecRefOn(memmap.MappableRange{offset, offset + uint64(ar.Length())}) +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (fd *specialFileFD) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { + return fd.AddMapping(ctx, ms, dstAR, offset, writable) +} + +// Translate implements memmap.Mappable.Translate. +func (fd *specialFileFD) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { + mr := optional + if fd.filesystem().opts.limitHostFDTranslation { + mr = maxFillRange(required, optional) + } + return []memmap.Translation{ + { + Source: mr, + File: fd, + Offset: mr.Start, + Perms: hostarch.AnyAccess, + }, + }, nil +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (fd *specialFileFD) InvalidateUnsavable(ctx context.Context) error { + return nil +} + +// IncRef implements memmap.File.IncRef. +func (fd *specialFileFD) IncRef(fr memmap.FileRange) { + fd.fileRefsMu.Lock() + defer fd.fileRefsMu.Unlock() + fd.fileRefs.IncRefAndAccount(fr) +} + +// DecRef implements memmap.File.DecRef. +func (fd *specialFileFD) DecRef(fr memmap.FileRange) { + fd.fileRefsMu.Lock() + defer fd.fileRefsMu.Unlock() + fd.fileRefs.DecRefAndAccount(fr) +} + +// MapInternal implements memmap.File.MapInternal. +func (fd *specialFileFD) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) { + fd.requireHostFD() + return fd.hostFileMapper.MapInternal(fr, int(fd.handle.fd), at.Write) +} + +// FD implements memmap.File.FD. +func (fd *specialFileFD) FD() int { + fd.requireHostFD() + return int(fd.handle.fd) +} + +func (fd *specialFileFD) requireHostFD() { + if fd.handle.fd < 0 { + // This is possible if fd was successfully mmapped before saving, then + // was restored without a host FD. This is unrecoverable: without a + // host FD, we can't mmap this file post-restore. + panic("gofer.specialFileFD can no longer be memory-mapped without a host FD") + } +} diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go index dbd834c67..27d9be5c4 100644 --- a/pkg/sentry/fsimpl/gofer/symlink.go +++ b/pkg/sentry/fsimpl/gofer/symlink.go @@ -35,7 +35,13 @@ func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { return target, nil } } - target, err := d.file.readlink(ctx) + var target string + var err error + if d.fs.opts.lisaEnabled { + target, err = d.controlFDLisa.ReadLinkAt(ctx) + } else { + target, err = d.file.readlink(ctx) + } if d.fs.opts.interop != InteropModeShared { if err == nil { d.haveTarget = true diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 9cbe805b9..07940b225 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -17,6 +17,7 @@ package gofer import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -24,6 +25,10 @@ func dentryTimestampFromP9(s, ns uint64) int64 { return int64(s*1e9 + ns) } +func dentryTimestampFromLisa(t linux.StatxTimestamp) int64 { + return t.Sec*1e9 + int64(t.Nsec) +} + // Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime || mnt.ReadOnly() { diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 476545d00..180a35583 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -70,7 +70,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/unet", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 89aa7b3d9..984c6e8ee 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -37,7 +37,6 @@ import ( unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -712,7 +711,7 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts if total != 0 { err = nil } else { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } } return total, err @@ -766,7 +765,7 @@ func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opt if !i.seekable { n, err := f.writeToHostFD(ctx, src, -1, opts.Flags) if isBlockError(err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } return n, err } diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index 7f6ce4ee5..04ac73255 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -346,7 +345,7 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) // If the signal is SIGTTIN, then we are attempting to read // from the TTY. Don't send the signal and return EIO. if sig == linux.SIGTTIN { - return syserror.EIO + return linuxerr.EIO } // Otherwise, we are writing or changing terminal state. This is allowed. @@ -355,7 +354,7 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) // If the process group is an orphan, return EIO. if pg.IsOrphan() { - return syserror.EIO + return linuxerr.EIO } // Otherwise, send the signal to the process group and return ERESTARTSYS. @@ -368,5 +367,5 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) // // Linux ignores the result of kill_pgrp(). _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) - return syserror.ERESTARTSYS + return linuxerr.ERESTARTSYS } diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 95d7ebe2e..9850f3f41 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -42,7 +42,7 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { } // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. -// If so, they can be transformed into syserror.ErrWouldBlock. +// If so, they can be transformed into linuxerr.ErrWouldBlock. func isBlockError(err error) bool { return linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) } diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index d53937db6..4b577ea43 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -119,7 +119,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", ], ) @@ -137,6 +136,7 @@ go_test( "//pkg/abi/linux", "//pkg/context", "//pkg/errors/linuxerr", + "//pkg/fspath", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", @@ -144,7 +144,6 @@ go_test( "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 8b008dc10..7db1473c4 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -99,7 +98,7 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *Dentry, children *OrderedChildren, l func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) error { if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 { // Can't open directories for writing. - return syserror.EISDIR + return linuxerr.EISDIR } fd.LockFD.Init(locks) fd.seekEnd = fdOpts.SeekEnd diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index a97473f7d..363ebc466 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // stepExistingLocked resolves rp.Component() in parent directory vfsd. @@ -224,7 +223,7 @@ func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string return linuxerr.EEXIST } if parent.VFSDentry().IsDead() { - return syserror.ENOENT + return linuxerr.ENOENT } if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite); err != nil { return err @@ -241,7 +240,7 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) er return linuxerr.EBUSY } if parent.vfsd.IsDead() { - return syserror.ENOENT + return linuxerr.ENOENT } if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err @@ -362,7 +361,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return err } if rp.MustBeDir() { - return syserror.ENOENT + return linuxerr.ENOENT } if rp.Mount() != vd.Mount() { return linuxerr.EXDEV @@ -443,7 +442,7 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } if rp.MustBeDir() { - return syserror.ENOENT + return linuxerr.ENOENT } if err := rp.Mount().CheckBeginWrite(); err != nil { return err @@ -509,7 +508,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf defer unlock() if rp.Done() { if rp.MustBeDir() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if mustCreate { return nil, linuxerr.EEXIST @@ -536,11 +535,11 @@ afterTrailingSymlink: } // Reject attempts to open directories with O_CREAT. if rp.MustBeDir() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } pc := rp.Component() if pc == "." || pc == ".." { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if len(pc) > linux.NAME_MAX { return nil, linuxerr.ENAMETOOLONG @@ -861,7 +860,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ return err } if rp.MustBeDir() { - return syserror.ENOENT + return linuxerr.ENOENT } if err := rp.Mount().CheckBeginWrite(); err != nil { return err @@ -895,7 +894,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } if d.isDir() { - return syserror.EISDIR + return linuxerr.EISDIR } virtfs := rp.VirtualFilesystem() parentDentry := d.parent diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index a42fc79b4..b96dc9ef7 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -26,7 +26,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // InodeNoopRefCount partially implements the Inode interface, specifically the @@ -234,6 +233,11 @@ func (a *InodeAttrs) Mode() linux.FileMode { return linux.FileMode(atomic.LoadUint32(&a.mode)) } +// Links returns the link count. +func (a *InodeAttrs) Links() uint32 { + return atomic.LoadUint32(&a.nlink) +} + // TouchAtime updates a.atime to the current time. func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) { if mnt.Flags.NoATime || mnt.ReadOnly() { @@ -289,7 +293,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut return linuxerr.EPERM } if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { - return syserror.EISDIR + return linuxerr.EISDIR } if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { return err @@ -475,7 +479,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error s, ok := o.set[name] if !ok { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } s.inode.IncRef() // This ref is passed to the dentry upon creation via Init. @@ -502,6 +506,30 @@ func (o *OrderedChildren) Insert(name string, child Inode) error { return o.insert(name, child, false) } +// Inserter is like Insert, but obtains the child to insert by calling +// makeChild. makeChild is only called if the insert will succeed. This allows +// the caller to atomically check and insert a child without having to +// clean up the child on failure. +func (o *OrderedChildren) Inserter(name string, makeChild func() Inode) (Inode, error) { + o.mu.Lock() + defer o.mu.Unlock() + if _, ok := o.set[name]; ok { + return nil, linuxerr.EEXIST + } + + // Note: We must not fail after we call makeChild(). + + child := makeChild() + s := &slot{ + name: name, + inode: child, + static: false, + } + o.order.PushBack(s) + o.set[name] = s + return child, nil +} + // insert inserts child into o. // // Precondition: Caller must be holding a ref on child if static is true. @@ -559,7 +587,7 @@ func (o *OrderedChildren) replaceChildLocked(ctx context.Context, name string, n func (o *OrderedChildren) checkExistingLocked(name string, child Inode) error { s, ok := o.set[name] if !ok { - return syserror.ENOENT + return linuxerr.ENOENT } if s.inode != child { panic(fmt.Sprintf("Inode doesn't match what kernfs thinks! OrderedChild: %+v, kernfs: %+v", s.inode, child)) @@ -746,5 +774,5 @@ type InodeNoStatFS struct{} // StatFS implements Inode.StatFS. func (*InodeNoStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { - return linux.Statfs{}, syserror.ENOSYS + return linux.Statfs{}, linuxerr.ENOSYS } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 0e2867d49..544698694 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -61,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -542,6 +543,63 @@ func (d *Dentry) FSLocalPath() string { return b.String() } +// WalkDentryTree traverses p in the dentry tree for this filesystem. Note that +// this only traverses the dentry tree and is not a general path traversal. No +// symlinks and dynamic children are resolved, and no permission checks are +// performed. The caller is responsible for ensuring the returned Dentry exists +// for an appropriate lifetime. +// +// p is interpreted starting at d, and may be absolute or relative (absolute vs +// relative paths both refer to the same target here, since p is absolute from +// d). p may contain "." and "..", but will not allow traversal above d (similar +// to ".." at the root dentry). +// +// This is useful for filesystem internals, where the filesystem may not be +// mounted yet. For a mounted filesystem, use GetDentryAt. +func (d *Dentry) WalkDentryTree(ctx context.Context, vfsObj *vfs.VirtualFilesystem, p fspath.Path) (*Dentry, error) { + d.fs.mu.RLock() + defer d.fs.processDeferredDecRefs(ctx) + defer d.fs.mu.RUnlock() + + target := d + + for pit := p.Begin; pit.Ok(); pit = pit.Next() { + pc := pit.String() + + switch { + case target == nil: + return nil, linuxerr.ENOENT + case pc == ".": + // No-op, consume component and continue. + case pc == "..": + if target == d { + // Don't let .. traverse above the start point of the walk. + continue + } + target = target.parent + // Parent doesn't need revalidation since we revalidated it on the + // way to the child, and we're still holding fs.mu. + default: + var err error + + d.dirMu.Lock() + target, err = d.fs.revalidateChildLocked(ctx, vfsObj, target, pc, target.children[pc]) + d.dirMu.Unlock() + + if err != nil { + return nil, err + } + } + } + + if target == nil { + return nil, linuxerr.ENOENT + } + + target.IncRef() + return target, nil +} + // The Inode interface maps filesystem-level operations that operate on paths to // equivalent operations on specific filesystem nodes. // @@ -667,12 +725,15 @@ type inodeDirectory interface { // RmDir removes an empty child directory from this directory // inode. Implementations must update the parent directory's link count, // if required. Implementations are not responsible for checking that child - // is a directory, checking for an empty directory. + // is a directory, or checking for an empty directory. RmDir(ctx context.Context, name string, child Inode) error // Rename is called on the source directory containing an inode being - // renamed. child should point to the resolved child in the source - // directory. + // renamed. child points to the resolved child in the source directory. + // dstDir is guaranteed to be a directory inode. + // + // On a successful call to Rename, the caller updates the dentry tree to + // reflect the name change. // // Precondition: Caller must serialize concurrent calls to Rename. Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 609887943..a2aba9321 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" @@ -346,3 +347,63 @@ func TestDirFDIterDirents(t *testing.T) { "file1": linux.DT_REG, }) } + +func TestDirWalkDentryTree(t *testing.T) { + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), + "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), + "dir3": fs.newDir(ctx, creds, 0755, nil), + }), + }) + }) + defer sys.Destroy() + + testWalk := func(from *kernfs.Dentry, getDentryPath, walkPath string, expectedErr error) { + var d *kernfs.Dentry + if getDentryPath != "" { + pop := sys.PathOpAtRoot(getDentryPath) + vd := sys.GetDentryOrDie(pop) + defer vd.DecRef(sys.Ctx) + d = vd.Dentry().Impl().(*kernfs.Dentry) + } + + match, err := from.WalkDentryTree(sys.Ctx, sys.VFS, fspath.Parse(walkPath)) + if err == nil { + defer match.DecRef(sys.Ctx) + } + + if err != expectedErr { + t.Fatalf("WalkDentryTree from %q to %q (with expected error: %v) unexpected error, want: %v, got: %v", from.FSLocalPath(), walkPath, expectedErr, expectedErr, err) + } + if expectedErr != nil { + return + } + + if d != match { + t.Fatalf("WalkDentryTree from %q to %q (with expected error: %v) found unexpected dentry; want: %v, got: %v", from.FSLocalPath(), walkPath, expectedErr, d, match) + } + } + + rootD := sys.Root.Dentry().Impl().(*kernfs.Dentry) + + testWalk(rootD, "dir1", "/dir1", nil) + testWalk(rootD, "", "/dir-non-existent", linuxerr.ENOENT) + testWalk(rootD, "", "/dir1/child-non-existent", linuxerr.ENOENT) + testWalk(rootD, "", "/dir2/inner-non-existent/dir3", linuxerr.ENOENT) + + testWalk(rootD, "dir2/dir3", "/dir2/../dir2/dir3", nil) + testWalk(rootD, "dir2/dir3", "/dir2/././dir3", nil) + testWalk(rootD, "dir2/dir3", "/dir2/././dir3/.././dir3", nil) + + pop := sys.PathOpAtRoot("dir2") + dir2VD := sys.GetDentryOrDie(pop) + defer dir2VD.DecRef(sys.Ctx) + dir2D := dir2VD.Dentry().Impl().(*kernfs.Dentry) + + testWalk(dir2D, "dir2/dir3", "/dir3", nil) + testWalk(dir2D, "dir2/dir3", "/../../../dir3", nil) + testWalk(dir2D, "dir2/file1", "/file1", nil) + testWalk(dir2D, "dir2/file1", "file1", nil) +} diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index ed730e215..d16dfef9b 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -42,7 +42,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 1f85a1f0d..520487066 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) func (d *dentry) isCopiedUp() bool { @@ -37,6 +36,10 @@ func (d *dentry) isCopiedUp() bool { // // Preconditions: filesystem.renameMu must be locked. func (d *dentry) copyUpLocked(ctx context.Context) error { + return d.copyUpMaybeSyntheticMountpointLocked(ctx, false /* forSyntheticMountpoint */) +} + +func (d *dentry) copyUpMaybeSyntheticMountpointLocked(ctx context.Context, forSyntheticMountpoint bool) error { // Fast path. if d.isCopiedUp() { return nil @@ -60,7 +63,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { // d is a filesystem root with no upper layer. return linuxerr.EROFS } - if err := d.parent.copyUpLocked(ctx); err != nil { + if err := d.parent.copyUpMaybeSyntheticMountpointLocked(ctx, forSyntheticMountpoint); err != nil { return err } @@ -72,7 +75,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } if d.vfsd.IsDead() { // Raced with deletion of d. - return syserror.ENOENT + return linuxerr.ENOENT } // Obtain settable timestamps from the lower layer. @@ -169,7 +172,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { case linux.S_IFDIR: if err := vfsObj.MkdirAt(ctx, d.fs.creds, &newpop, &vfs.MkdirOptions{ - Mode: linux.FileMode(d.mode &^ linux.S_IFMT), + Mode: linux.FileMode(d.mode &^ linux.S_IFMT), + ForSyntheticMountpoint: forSyntheticMountpoint, }); err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 5e89928c5..3b3dcf836 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // _OVL_XATTR_PREFIX is an extended attribute key prefix to identify overlayfs @@ -314,7 +313,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str } if !topLookupLayer.existsInOverlay() { child.destroyLocked(ctx) - return nil, topLookupLayer, syserror.ENOENT + return nil, topLookupLayer, linuxerr.ENOENT } // Device and inode numbers were copied from the topmost layer above. Remap @@ -463,13 +462,21 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, return d, nil } +type createType int + +const ( + createNonDirectory createType = iota + createDirectory + createSyntheticMountpoint +) + // doCreateAt checks that creating a file at rp is permitted, then invokes // create to do so. // // Preconditions: // * !rp.Done(). // * For the final path component in rp, !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error { +func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, ct createType, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) @@ -483,7 +490,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return linuxerr.EEXIST } if parent.vfsd.IsDead() { - return syserror.ENOENT + return linuxerr.ENOENT } if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { @@ -505,8 +512,8 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return linuxerr.EEXIST } - if !dir && rp.MustBeDir() { - return syserror.ENOENT + if ct == createNonDirectory && rp.MustBeDir() { + return linuxerr.ENOENT } mnt := rp.Mount() @@ -519,7 +526,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir } // Ensure that the parent directory is copied-up so that we can create the // new file in the upper layer. - if err := parent.copyUpLocked(ctx); err != nil { + if err := parent.copyUpMaybeSyntheticMountpointLocked(ctx, ct == createSyntheticMountpoint); err != nil { return err } @@ -530,7 +537,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir parent.dirents = nil ev := linux.IN_CREATE - if dir { + if ct != createNonDirectory { ev |= linux.IN_ISDIR } parent.watches.Notify(ctx, name, uint32(ev), 0 /* cookie */, vfs.InodeEvent, false /* unlinked */) @@ -619,7 +626,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error { + return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error { if rp.Mount() != vd.Mount() { return linuxerr.EXDEV } @@ -672,7 +679,11 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error { + ct := createDirectory + if opts.ForSyntheticMountpoint { + ct = createSyntheticMountpoint + } + return fs.doCreateAt(ctx, rp, ct, func(parent *dentry, childName string, haveUpperWhiteout bool) error { vfsObj := fs.vfsfs.VirtualFilesystem() pop := vfs.PathOperation{ Root: parent.upperVD, @@ -723,7 +734,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error { + return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error { // Disallow attempts to create whiteouts. if opts.Mode&linux.S_IFMT == linux.S_IFCHR && opts.DevMajor == 0 && opts.DevMinor == 0 { return linuxerr.EPERM @@ -780,7 +791,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf start := rp.Start().Impl().(*dentry) if rp.Done() { if mayCreate && rp.MustBeDir() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if mustCreate { return nil, linuxerr.EEXIST @@ -807,7 +818,7 @@ afterTrailingSymlink: } // Reject attempts to open directories with O_CREAT. if mayCreate && rp.MustBeDir() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } // Determine whether or not we need to create a file. parent.dirMu.Lock() @@ -865,11 +876,11 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts * if ftype == linux.S_IFDIR { // Can't open directories with O_CREAT. if opts.Flags&linux.O_CREAT != 0 { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } // Can't open directories writably. if ats.MayWrite() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if opts.Flags&linux.O_DIRECT != 0 { return nil, linuxerr.EINVAL @@ -919,7 +930,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving return nil, err } if parent.vfsd.IsDead() { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } mnt := rp.Mount() if err := mnt.CheckBeginWrite(); err != nil { @@ -1086,7 +1097,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa defer newParent.dirMu.Unlock() } if newParent.vfsd.IsDead() { - return syserror.ENOENT + return linuxerr.ENOENT } var ( replaced *dentry @@ -1105,7 +1116,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { - return syserror.EISDIR + return linuxerr.EISDIR } if genericIsAncestorDentry(replaced, renamed) { return linuxerr.ENOTEMPTY @@ -1477,7 +1488,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error { + return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error { vfsObj := fs.vfsfs.VirtualFilesystem() pop := vfs.PathOperation{ Root: parent.upperVD, @@ -1533,7 +1544,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error defer rp.Mount().EndWrite() name := rp.Component() if name == "." || name == ".." { - return syserror.EISDIR + return linuxerr.EISDIR } if rp.MustBeDir() { return linuxerr.ENOTDIR @@ -1557,7 +1568,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } if child.isDir() { - return syserror.EISDIR + return linuxerr.EISDIR } if err := parent.mayDelete(rp.Credentials(), child); err != nil { return err diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 1d3d2d95f..95cfbdc42 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -102,7 +102,6 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/tcpip/header", "//pkg/tcpip/network/ipv4", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index d99f90b36..e04ae6660 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // subtasksInode represents the inode for /proc/[pid]/task/ directory. @@ -71,15 +70,15 @@ func (fs *filesystem) newSubtasks(ctx context.Context, task *kernel.Task, pidns func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) { tid, err := strconv.ParseUint(name, 10, 32) if err != nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } subTask := i.pidns.TaskWithID(kernel.ThreadID(tid)) if subTask == nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } if subTask.ThreadGroup() != i.task.ThreadGroup() { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } return i.fs.newTaskInode(ctx, subTask, i.pidns, false, i.cgroupControllers) } @@ -88,7 +87,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { - return offset, syserror.ENOENT + return offset, linuxerr.ENOENT } if relOffset >= int64(len(tasks)) { return offset, nil @@ -124,7 +123,7 @@ type subtasksFD struct { func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { if fd.task.ExitState() >= kernel.TaskExitZombie { - return syserror.ENOENT + return linuxerr.ENOENT } return fd.GenericDirectoryFD.IterDirents(ctx, cb) } @@ -132,7 +131,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac // Seek implements vfs.FileDescriptionImpl.Seek. func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { if fd.task.ExitState() >= kernel.TaskExitZombie { - return 0, syserror.ENOENT + return 0, linuxerr.ENOENT } return fd.GenericDirectoryFD.Seek(ctx, offset, whence) } @@ -140,7 +139,7 @@ func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *subtasksFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { if fd.task.ExitState() >= kernel.TaskExitZombie { - return linux.Statx{}, syserror.ENOENT + return linux.Statx{}, linuxerr.ENOENT } return fd.GenericDirectoryFD.Stat(ctx, opts) } @@ -148,7 +147,7 @@ func (fd *subtasksFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Sta // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { if fd.task.ExitState() >= kernel.TaskExitZombie { - return syserror.ENOENT + return linuxerr.ENOENT } return fd.GenericDirectoryFD.SetStat(ctx, opts) } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index dfc0a924e..5c6412fc0 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -22,11 +22,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) { @@ -142,11 +142,11 @@ func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.Ite func (i *fdDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) { fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } fd := int32(fdInt) if !taskFDExists(ctx, i.fs, i.task, fd) { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } return i.fs.newFDSymlink(ctx, i.task, fd, i.fs.NextIno()), nil } @@ -218,7 +218,7 @@ func (fs *filesystem) newFDSymlink(ctx context.Context, task *kernel.Task, fd in func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { file, _ := getTaskFD(s.task, s.fd) if file == nil { - return "", syserror.ENOENT + return "", linuxerr.ENOENT } defer s.fs.SafeDecRefFD(ctx, file) root := vfs.RootFromContext(ctx) @@ -231,7 +231,7 @@ func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { file, _ := getTaskFD(s.task, s.fd) if file == nil { - return vfs.VirtualDentry{}, "", syserror.ENOENT + return vfs.VirtualDentry{}, "", linuxerr.ENOENT } defer s.fs.SafeDecRefFD(ctx, file) vd := file.VirtualDentry() @@ -278,11 +278,11 @@ func (fs *filesystem) newFDInfoDirInode(ctx context.Context, task *kernel.Task) func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) { fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } fd := int32(fdInt) if !taskFDExists(ctx, i.fs, i.task, fd) { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } data := &fdInfoData{ fs: i.fs, @@ -330,7 +330,7 @@ var _ dynamicInode = (*fdInfoData)(nil) func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { file, descriptorFlags := getTaskFD(d.task, d.fd) if file == nil { - return syserror.ENOENT + return linuxerr.ENOENT } defer d.fs.SafeDecRefFD(ctx, file) // TODO(b/121266871): Include pos, locks, and other data. For now we only diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 0ce3ed797..d3f9cf489 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -33,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -41,7 +40,7 @@ import ( // Linux 3.18, the limit is five lines." - user_namespaces(7) const maxIDMapLines = 5 -// mm gets the kernel task's MemoryManager. No additional reference is taken on +// getMM gets the kernel task's MemoryManager. No additional reference is taken on // mm here. This is safe because MemoryManager.destroy is required to leave the // MemoryManager in a state where it's still usable as a DynamicBytesSource. func getMM(task *kernel.Task) *mm.MemoryManager { @@ -491,7 +490,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64 return int64(n), nil } if readErr != nil { - return 0, syserror.EIO + return 0, linuxerr.EIO } return 0, nil } @@ -609,12 +608,10 @@ func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime()))) var vss, rss uint64 - s.task.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) + if mm := getMM(s.task); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } fmt.Fprintf(buf, "%d %d ", vss, rss/hostarch.PageSize) // rsslim. @@ -650,13 +647,10 @@ var _ dynamicInode = (*statmData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { var vss, rss uint64 - s.task.WithMuLocked(func(t *kernel.Task) { - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - } - }) - + if mm := getMM(s.task); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize) return nil } @@ -780,12 +774,12 @@ func (s *statusFD) Generate(ctx context.Context, buf *bytes.Buffer) error { if fdTable := t.FDTable(); fdTable != nil { fds = fdTable.CurrentMaxFDs() } - if mm := t.MemoryManager(); mm != nil { - vss = mm.VirtualMemorySize() - rss = mm.ResidentSetSize() - data = mm.VirtualDataSize() - } }) + if mm := getMM(s.task); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + data = mm.VirtualDataSize() + } // Filesystem user/group IDs aren't implemented; effective UID/GID are used // instead. fmt.Fprintf(buf, "Uid:\t%d\t%d\t%d\t%d\n", ruid, euid, suid, euid) @@ -946,25 +940,17 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent return vfs.VirtualDentry{}, "", err } - var err error - var exec fsbridge.File - s.task.WithMuLocked(func(t *kernel.Task) { - mm := t.MemoryManager() - if mm == nil { - err = linuxerr.EACCES - return - } + mm := getMM(s.task) + if mm == nil { + return vfs.VirtualDentry{}, "", linuxerr.EACCES + } - // The MemoryManager may be destroyed, in which case - // MemoryManager.destroy will simply set the executable to nil - // (with locks held). - exec = mm.Executable() - if exec == nil { - err = linuxerr.ESRCH - } - }) - if err != nil { - return vfs.VirtualDentry{}, "", err + // The MemoryManager may be destroyed, in which case + // MemoryManager.destroy will simply set the executable to nil + // (with locks held). + exec := mm.Executable() + if exec == nil { + return vfs.VirtualDentry{}, "", linuxerr.ESRCH } defer exec.DecRef(ctx) diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index cf905fae4..7b0be9c14 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -21,11 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -116,12 +116,12 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err case threadSelfName: return i.newThreadSelfSymlink(ctx, root), nil } - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } task := i.pidns.TaskWithID(kernel.ThreadID(tid)) if task == nil { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers) @@ -268,6 +268,6 @@ func cpuInfoData(k *kernel.Kernel) string { return buf.String() } -func shmData(v uint64) dynamicInode { +func ipcData(v uint64) dynamicInode { return newStaticFile(strconv.FormatUint(v, 10)) } diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 03bed22a3..4d3a2f7e6 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // +stateify savable @@ -58,7 +57,7 @@ func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error } tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) if tgid == 0 { - return "", syserror.ENOENT + return "", linuxerr.ENOENT } return strconv.FormatUint(uint64(tgid), 10), nil } @@ -100,7 +99,7 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) tid := s.pidns.IDOfTask(t) if tid == 0 || tgid == 0 { - return "", syserror.ENOENT + return "", linuxerr.ENOENT } return fmt.Sprintf("%d/task/%d", tgid, tid), nil } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 99f64a9d8..82e2857b3 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -47,9 +47,12 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k * "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), - "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), - "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), - "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), + "shmall": fs.newInode(ctx, root, 0444, ipcData(linux.SHMALL)), + "shmmax": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMAX)), + "shmmni": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMNI)), + "msgmni": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNI)), + "msgmax": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMAX)), + "msgmnb": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNB)), "yama": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "ptrace_scope": fs.newYAMAPtraceScopeFile(ctx, k, root), }), diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD index adb610213..403c6f254 100644 --- a/pkg/sentry/fsimpl/signalfd/BUILD +++ b/pkg/sentry/fsimpl/signalfd/BUILD @@ -9,10 +9,10 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/kernel", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index a7f5928b7..bdb03ef96 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -18,10 +18,10 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -91,7 +91,7 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen info, err := sfd.target.Sigtimedwait(sfd.Mask(), 0) if err != nil { // There must be no signal available. - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // Copy out the signal info using the specified format. diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 1af0a5cbc..ab21f028e 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -36,7 +36,6 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index f322d2747..7fcb2d26b 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -84,6 +84,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + k := kernel.KernelFromContext(ctx) + fsDirChildren := make(map[string]kernfs.Inode) + // Create an empty directory to serve as the mount point for cgroupfs when + // cgroups are available. This emulates Linux behaviour, see + // kernel/cgroup.c:cgroup_init(). Note that in Linux, userspace (typically + // the init process) is ultimately responsible for actually mounting + // cgroupfs, but the kernel creates the mountpoint. For the sentry, the + // launcher mounts cgroupfs. + if k.CgroupRegistry() != nil { + fsDirChildren["cgroup"] = fs.newDir(ctx, creds, defaultSysDirMode, nil) + } + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), @@ -97,7 +109,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt }), }), "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), - "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, fsDirChildren), "kernel": kernelDir(ctx, fs, creds), "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 0a0d914cc..0c46a3a13 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -87,3 +87,17 @@ func TestSysRootContainsExpectedEntries(t *testing.T) { "power": linux.DT_DIR, }) } + +func TestCgroupMountpointExists(t *testing.T) { + // Note: The mountpoint is only created if cgroups are available. This is + // the VFS2 implementation of sysfs and the test runs with VFS2 enabled, so + // we expect to see the mount point unconditionally. + s := newTestSystem(t) + defer s.Destroy() + pop := s.PathOpAtRoot("/fs") + s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ + "cgroup": linux.DT_DIR, + }) + pop = s.PathOpAtRoot("/fs/cgroup") + s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ /*empty*/ }) +} diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD index e6980a314..2b83d7d9a 100644 --- a/pkg/sentry/fsimpl/timerfd/BUILD +++ b/pkg/sentry/fsimpl/timerfd/BUILD @@ -12,7 +12,6 @@ go_library( "//pkg/hostarch", "//pkg/sentry/kernel/time", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go index 655a1c76a..68b785791 100644 --- a/pkg/sentry/fsimpl/timerfd/timerfd.go +++ b/pkg/sentry/fsimpl/timerfd/timerfd.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -82,7 +81,7 @@ func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequenc } return sizeofUint64, nil } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // Clock returns the timer fd's Clock. diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index dc8b9bfeb..94486bb63 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -82,7 +82,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sentry/vfs/memxattr", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", ], ) @@ -125,7 +124,6 @@ go_test( "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 8b04df038..e067f136e 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // Sync implements vfs.FilesystemImpl.Sync. @@ -75,7 +74,7 @@ afterSymlink: } child, ok := dir.childMap[name] if !ok { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err @@ -171,12 +170,12 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return linuxerr.EEXIST } if !dir && rp.MustBeDir() { - return syserror.ENOENT + return linuxerr.ENOENT } // tmpfs never calls VFS.InvalidateDentry(), so parentDir.dentry can only // be dead if it was deleted. if parentDir.dentry.vfsd.IsDead() { - return syserror.ENOENT + return linuxerr.ENOENT } mnt := rp.Mount() if err := mnt.CheckBeginWrite(); err != nil { @@ -258,7 +257,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return err } if i.nlink == 0 { - return syserror.ENOENT + return linuxerr.ENOENT } if i.nlink == maxLinks { return linuxerr.EMLINK @@ -345,7 +344,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if rp.Done() { // Reject attempts to open mount root directory with O_CREAT. if rp.MustBeDir() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if mustCreate { return nil, linuxerr.EEXIST @@ -366,11 +365,11 @@ afterTrailingSymlink: } // Reject attempts to open directories with O_CREAT. if rp.MustBeDir() { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } name := rp.Component() if name == "." || name == ".." { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } if len(name) > linux.NAME_MAX { return nil, linuxerr.ENAMETOOLONG @@ -457,7 +456,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open case *directory: // Can't open directories writably. if ats&vfs.MayWrite != 0 { - return nil, syserror.EISDIR + return nil, linuxerr.EISDIR } var fd directoryFD fd.LockFD.Init(&d.inode.locks) @@ -532,7 +531,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } renamed, ok := oldParentDir.childMap[oldName] if !ok { - return syserror.ENOENT + return linuxerr.ENOENT } if err := oldParentDir.mayDelete(rp.Credentials(), renamed); err != nil { return err @@ -567,7 +566,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa replacedDir, ok := replaced.inode.impl.(*directory) if ok { if !renamed.inode.isDir() { - return syserror.EISDIR + return linuxerr.EISDIR } if len(replacedDir.childMap) != 0 { return linuxerr.ENOTEMPTY @@ -588,7 +587,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // tmpfs never calls VFS.InvalidateDentry(), so newParentDir.dentry can // only be dead if it was deleted. if newParentDir.dentry.vfsd.IsDead() { - return syserror.ENOENT + return linuxerr.ENOENT } // Linux places this check before some of those above; we do it here for @@ -654,7 +653,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error } child, ok := parentDir.childMap[name] if !ok { - return syserror.ENOENT + return linuxerr.ENOENT } if err := parentDir.mayDelete(rp.Credentials(), child); err != nil { return err @@ -754,17 +753,17 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } name := rp.Component() if name == "." || name == ".." { - return syserror.EISDIR + return linuxerr.EISDIR } child, ok := parentDir.childMap[name] if !ok { - return syserror.ENOENT + return linuxerr.ENOENT } if err := parentDir.mayDelete(rp.Credentials(), child); err != nil { return err } if child.inode.isDir() { - return syserror.EISDIR + return linuxerr.EISDIR } if rp.MustBeDir() { return linuxerr.ENOTDIR diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index 418c7994e..99afd9817 100644 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -202,7 +201,7 @@ func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { readData := make([]byte, 1) dst := usermem.BytesIOSequence(readData) bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err) } if bytesRead != 0 { diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go index 4393cc13b..cb7711b39 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go @@ -21,10 +21,10 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -146,7 +146,7 @@ func TestLocks(t *testing.T) { if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if got, want := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil), linuxerr.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want) } if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil { @@ -165,7 +165,7 @@ func TestLocks(t *testing.T) { if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.WriteLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if got, want := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil), linuxerr.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want) } if err := fd.Impl().UnlockPOSIX(ctx, uid1, lock.LockRange{Start: 0, End: 1}); err != nil { diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index f2250c025..feafb06e4 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -44,7 +44,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // Name is the default filesystem name. @@ -556,7 +555,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs. needsCtimeBump = true } case *directory: - return syserror.EISDIR + return linuxerr.EISDIR default: return linuxerr.EINVAL } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 1d855234c..c12abdf33 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -1,10 +1,24 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "dentry_list", + out = "dentry_list.go", + package = "verity", + prefix = "dentry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*dentry", + "Linker": "*dentry", + }, +) + go_library( name = "verity", srcs = [ + "dentry_list.go", "filesystem.go", "save_restore.go", "verity.go", @@ -28,7 +42,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 930016a3e..52d47994d 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -32,7 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -67,40 +66,23 @@ func putDentrySlice(ds *[]*dentry) { dentrySlicePool.Put(ds) } -// renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls -// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for +// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls +// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for // writing. // // ds is a pointer-to-pointer since defer evaluates its arguments immediately, // but dentry slices are allocated lazily, and it's much easier to say "defer -// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() { -// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this. +// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { +// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. // +checklocksrelease:fs.renameMu -func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { +func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { fs.renameMu.RUnlock() if *ds == nil { return } - if len(**ds) != 0 { - fs.renameMu.Lock() - for _, d := range **ds { - d.checkDropLocked(ctx) - } - fs.renameMu.Unlock() - } - putDentrySlice(*ds) -} - -// +checklocksrelease:fs.renameMu -func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { - if *ds == nil { - fs.renameMu.Unlock() - return - } for _, d := range **ds { - d.checkDropLocked(ctx) + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } - fs.renameMu.Unlock() putDentrySlice(*ds) } @@ -166,7 +148,7 @@ afterSymlink: // verifyChildLocked verifies the hash of child against the already verified // hash of the parent to ensure the child is expected. verifyChild triggers a // sentry panic if unexpected modifications to the file system are detected. In -// ErrorOnViolation mode it returns a syserror instead. +// ErrorOnViolation mode it returns a linuxerr instead. // // Preconditions: // * fs.renameMu must be locked. @@ -547,7 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, if parent.verityEnabled() { if _, ok := parent.childrenNames[name]; !ok { - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } } @@ -595,23 +577,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } } - // Clear the Merkle tree file if they are to be generated at runtime. - // TODO(b/182315468): Optimize the Merkle tree generate process to - // allow only updating certain files/directories. - if fs.allowRuntimeEnable { - childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ - Root: childMerkleVD, - Start: childMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_TRUNC, - Mode: 0644, - }) - if err != nil { - return nil, err - } - childMerkleFD.DecRef(ctx) - } - // The dentry needs to be cleaned up if any error occurs. IncRef will be // called if a verity child dentry is successfully created. defer childMerkleVD.DecRef(ctx) @@ -718,7 +683,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds } var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err @@ -730,7 +695,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -751,7 +716,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) d, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -788,7 +753,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) if rp.Done() { @@ -970,7 +935,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err @@ -1000,7 +965,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statx{}, err @@ -1046,7 +1011,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) if _, err := fs.resolveLocked(ctx, rp, &ds); err != nil { return nil, err } @@ -1057,7 +1022,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -1073,7 +1038,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index c5fa9855b..d2526263c 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -23,10 +23,12 @@ // Lock order: // // filesystem.renameMu -// dentry.dirMu -// fileDescription.mu -// filesystem.verityMu -// dentry.hashMu +// dentry.cachingMu +// filesystem.cacheMu +// dentry.dirMu +// fileDescription.mu +// filesystem.verityMu +// dentry.hashMu // // Locking dentry.dirMu in multiple dentries requires that parent dentries are // locked before child dentries, and that filesystem.renameMu is locked to @@ -60,7 +62,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -97,6 +98,9 @@ const ( // sizeOfStringInt32 is the size for a 32 bit integer stored as string in // extended attributes. The maximum value of a 32 bit integer has 10 digits. sizeOfStringInt32 = 10 + + // defaultMaxCachedDentries is the default limit of dentry cache. + defaultMaxCachedDentries = uint64(1000) ) var ( @@ -107,9 +111,10 @@ var ( // Mount option names for verityfs. const ( - moptLowerPath = "lower_path" - moptRootHash = "root_hash" - moptRootName = "root_name" + moptLowerPath = "lower_path" + moptRootHash = "root_hash" + moptRootName = "root_name" + moptDentryCacheLimit = "dentry_cache_limit" ) // HashAlgorithm is a type specifying the algorithm used to hash the file @@ -189,6 +194,17 @@ type filesystem struct { // dentries. renameMu sync.RWMutex `state:"nosave"` + // cachedDentries contains all dentries with 0 references. (Due to race + // conditions, it may also contain dentries with non-zero references.) + // cachedDentriesLen is the number of dentries in cachedDentries. These + // fields are protected by cacheMu. + cacheMu sync.Mutex `state:"nosave"` + cachedDentries dentryList + cachedDentriesLen uint64 + + // maxCachedDentries is the maximum size of filesystem.cachedDentries. + maxCachedDentries uint64 + // verityMu synchronizes enabling verity files, protects files or // directories from being enabled by different threads simultaneously. // It also ensures that verity does not access files that are being @@ -199,6 +215,10 @@ type filesystem struct { // is for the whole file system to ensure that no more than one file is // enabled the same time. verityMu sync.RWMutex `state:"nosave"` + + // released is nonzero once filesystem.Release has been called. It is accessed + // with atomic memory operations. + released int32 } // InternalFilesystemOptions may be passed as @@ -239,7 +259,7 @@ func (FilesystemType) Release(ctx context.Context) {} // mode, it returns EIO, otherwise it panic. func (fs *filesystem) alertIntegrityViolation(msg string) error { if fs.action == ErrorOnViolation { - return syserror.EIO + return linuxerr.EIO } panic(msg) } @@ -267,6 +287,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt delete(mopts, moptRootName) rootName = root } + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts[moptDentryCacheLimit]; ok { + delete(mopts, moptDentryCacheLimit) + maxCD, err := strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("verity.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str) + return nil, nil, linuxerr.EINVAL + } + maxCachedDentries = maxCD + } // Check for unparsed options. if len(mopts) != 0 { @@ -340,12 +370,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt action: iopts.Action, opts: opts.Data, allowRuntimeEnable: iopts.AllowRuntimeEnable, + maxCachedDentries: maxCachedDentries, } fs.vfsfs.Init(vfsObj, &fstype, fs) // Construct the root dentry. d := fs.newDentry() - d.refs = 1 + // Set the root's reference count to 2. One reference is returned to + // the caller, and the other is held by fs to prevent the root from + // being "cached" and subsequently evicted. + d.refs = 2 lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root()) lowerVD.IncRef() d.lowerVD = lowerVD @@ -520,7 +554,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { + atomic.StoreInt32(&fs.released, 1) fs.lowerMount.DecRef(ctx) + + fs.renameMu.Lock() + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // An extra reference was held by the filesystem on the root to prevent + // it from being cached/evicted. + fs.rootDentry.DecRef(ctx) } // MountOptions implements vfs.FilesystemImpl.MountOptions. @@ -534,6 +577,11 @@ func (fs *filesystem) MountOptions() string { type dentry struct { vfsd vfs.Dentry + // refs is the reference count. Each dentry holds a reference on its + // parent, even if disowned. When refs reaches 0, the dentry may be + // added to the cache or destroyed. If refs == -1, the dentry has + // already been destroyed. refs is accessed using atomic memory + // operations. refs int64 // fs is the owning filesystem. fs is immutable. @@ -588,13 +636,23 @@ type dentry struct { // is protected by hashMu. hashMu sync.RWMutex `state:"nosave"` hash []byte + + // cachingMu is used to synchronize concurrent dentry caching attempts on + // this dentry. + cachingMu sync.Mutex `state:"nosave"` + + // If cached is true, dentryEntry links dentry into + // filesystem.cachedDentries. cached and dentryEntry are protected by + // cachingMu. + cached bool + dentryEntry } // newDentry creates a new dentry representing the given verity file. The -// dentry initially has no references; it is the caller's responsibility to set -// the dentry's reference count and/or call dentry.destroy() as appropriate. -// The dentry is initially invalid in that it contains no underlying dentry; -// the caller is responsible for setting them. +// dentry initially has no references, but is not cached; it is the caller's +// responsibility to set the dentry's reference count and/or call +// dentry.destroy() as appropriate. The dentry is initially invalid in that it +// contains no underlying dentry; the caller is responsible for setting them. func (fs *filesystem) newDentry() *dentry { d := &dentry{ fs: fs, @@ -630,42 +688,23 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - r := atomic.AddInt64(&d.refs, -1) - if d.LogRefs() { - refsvfs2.LogDecRef(d, r) - } - if r == 0 { - d.fs.renameMu.Lock() - d.checkDropLocked(ctx) - d.fs.renameMu.Unlock() - } else if r < 0 { - panic("verity.dentry.DecRef() called without holding a reference") + if d.decRefNoCaching() == 0 { + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } } -func (d *dentry) decRefLocked(ctx context.Context) { +// decRefNoCaching decrements d's reference count without calling +// d.checkCachingLocked, even if d's reference count reaches 0; callers are +// responsible for ensuring that d.checkCachingLocked will be called later. +func (d *dentry) decRefNoCaching() int64 { r := atomic.AddInt64(&d.refs, -1) if d.LogRefs() { refsvfs2.LogDecRef(d, r) } - if r == 0 { - d.checkDropLocked(ctx) - } else if r < 0 { - panic("verity.dentry.decRefLocked() called without holding a reference") + if r < 0 { + panic("verity.dentry.decRefNoCaching() called without holding a reference") } -} - -// checkDropLocked should be called after d's reference count becomes 0 or it -// becomes deleted. -func (d *dentry) checkDropLocked(ctx context.Context) { - // Dentries with a positive reference count must be retained. Dentries - // with a negative reference count have already been destroyed. - if atomic.LoadInt64(&d.refs) != 0 { - return - } - // Refs is still zero; destroy it. - d.destroyLocked(ctx) - return + return r } // destroyLocked destroys the dentry. @@ -684,6 +723,12 @@ func (d *dentry) destroyLocked(ctx context.Context) { panic("verity.dentry.destroyLocked() called with references on the dentry") } + // Drop the reference held by d on its parent without recursively + // locking d.fs.renameMu. + if d.parent != nil && d.parent.decRefNoCaching() == 0 { + d.parent.checkCachingLocked(ctx, true /* renameMuWriteLocked */) + } + if d.lowerVD.Ok() { d.lowerVD.DecRef(ctx) } @@ -696,7 +741,6 @@ func (d *dentry) destroyLocked(ctx context.Context) { delete(d.parent.children, d.name) } d.parent.dirMu.Unlock() - d.parent.decRefLocked(ctx) } refsvfs2.Unregister(d) } @@ -735,6 +779,140 @@ func (d *dentry) OnZeroWatches(context.Context) { //TODO(b/159261227): Implement OnZeroWatches. } +// checkCachingLocked should be called after d's reference count becomes 0 or +// it becomes disowned. +// +// For performance, checkCachingLocked can also be called after d's reference +// count becomes non-zero, so that d can be removed from the LRU cache. This +// may help in reducing the size of the cache and hence reduce evictions. Note +// that this is not necessary for correctness. +// +// It may be called on a destroyed dentry. For example, +// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times +// for the same dentry when the dentry is visited more than once in the same +// operation. One of the calls may destroy the dentry, so subsequent calls will +// do nothing. +// +// Preconditions: d.fs.renameMu must be locked for writing if +// renameMuWriteLocked is true; it may be temporarily unlocked. +func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked bool) { + d.cachingMu.Lock() + refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + d.cachingMu.Unlock() + return + } + if refs > 0 { + // fs.cachedDentries is permitted to contain dentries with non-zero refs, + // which are skipped by fs.evictCachedDentryLocked() upon reaching the end + // of the LRU. But it is still beneficial to remove d from the cache as we + // are already holding d.cachingMu. Keeping a cleaner cache also reduces + // the number of evictions (which is expensive as it acquires fs.renameMu). + d.removeFromCacheLocked() + d.cachingMu.Unlock() + return + } + + if atomic.LoadInt32(&d.fs.released) != 0 { + d.cachingMu.Unlock() + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu to access d.parent. Lock it for writing as + // needed by d.destroyLocked() later. + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + } + if d.parent != nil { + d.parent.dirMu.Lock() + delete(d.parent.children, d.name) + d.parent.dirMu.Unlock() + } + d.destroyLocked(ctx) // +checklocksforce: see above. + return + } + + d.fs.cacheMu.Lock() + // If d is already cached, just move it to the front of the LRU. + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentries.PushFront(d) + d.fs.cacheMu.Unlock() + d.cachingMu.Unlock() + return + } + // Cache the dentry, then evict the least recently used cached dentry if + // the cache becomes over-full. + d.fs.cachedDentries.PushFront(d) + d.fs.cachedDentriesLen++ + d.cached = true + shouldEvict := d.fs.cachedDentriesLen > d.fs.maxCachedDentries + d.fs.cacheMu.Unlock() + d.cachingMu.Unlock() + + if shouldEvict { + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu for writing as needed by + // d.evictCachedDentryLocked(). + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + } + d.fs.evictCachedDentryLocked(ctx) // +checklocksforce: see above. + } +} + +// Preconditions: d.cachingMu must be locked. +func (d *dentry) removeFromCacheLocked() { + if d.cached { + d.fs.cacheMu.Lock() + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.fs.cacheMu.Unlock() + d.cached = false + } +} + +// Precondition: fs.renameMu must be locked for writing; it may be temporarily +// unlocked. +// +checklocks:fs.renameMu +func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } +} + +// Preconditions: +// * fs.renameMu must be locked for writing; it may be temporarily unlocked. +// +checklocks:fs.renameMu +func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + fs.cacheMu.Lock() + victim := fs.cachedDentries.Back() + fs.cacheMu.Unlock() + if victim == nil { + // fs.cachedDentries may have become empty between when it was + // checked and when we locked fs.cacheMu. + return + } + + victim.cachingMu.Lock() + victim.removeFromCacheLocked() + // victim.refs may have become non-zero from an earlier path resolution + // since it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) != 0 { + victim.cachingMu.Unlock() + return + } + if victim.parent != nil { + victim.parent.dirMu.Lock() + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + victim.parent.dirMu.Unlock() + } + victim.cachingMu.Unlock() + victim.destroyLocked(ctx) // +checklocksforce: owned as precondition, victim.fs == fs. +} + func (d *dentry) isSymlink() bool { return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK } @@ -1091,6 +1269,21 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { return 0, fd.d.fs.alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } + // Populate children names here. We cannot rely on the children + // dentries to populate parent dentry's children names, because the + // parent dentry may be destroyed before users enable verity if its ref + // count drops to zero. + if fd.d.isDir() { + if err := fd.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { + if dirent.Name != "." && dirent.Name != ".." { + fd.d.childrenNames[dirent.Name] = struct{}{} + } + return nil + })); err != nil { + return 0, err + } + } + hash, dataSize, err := fd.generateMerkleLocked(ctx) if err != nil { return 0, err @@ -1118,9 +1311,6 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { }); err != nil { return 0, err } - - // Add the current child's name to parent's childrenNames. - fd.d.parent.childrenNames[fd.d.name] = struct{}{} } // Record the size of the data being hashed for fd. @@ -1215,7 +1405,7 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. case linux.FS_IOC_GETFLAGS: return fd.verityFlags(ctx, args[2].Pointer()) default: - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS } } diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index 66fa1ad40..03c8e2f38 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -12,8 +12,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/fd", - "//pkg/hostarch", + "//pkg/eventfd", "//pkg/log", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go index 285ea9050..5df06a60f 100644 --- a/pkg/sentry/hostmm/hostmm.go +++ b/pkg/sentry/hostmm/hostmm.go @@ -21,9 +21,7 @@ import ( "os" "path" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/log" ) @@ -54,7 +52,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) } defer eventControlFile.Close() - eventFD, err := newEventFD() + eventFD, err := eventfd.Create() if err != nil { return nil, err } @@ -75,20 +73,11 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) const stopVal = 1 << 63 stopCh := make(chan struct{}) go func() { // S/R-SAFE: f provides synchronization if necessary - rw := fd.NewReadWriter(eventFD.FD()) - var buf [sizeofUint64]byte for { - n, err := rw.Read(buf[:]) + val, err := eventFD.Read() if err != nil { - if err == unix.EINTR { - continue - } panic(fmt.Sprintf("failed to read from memory pressure level eventfd: %v", err)) } - if n != sizeofUint64 { - panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64)) - } - val := hostarch.ByteOrder.Uint64(buf[:]) if val >= stopVal { // Assume this was due to the notifier's "destructor" (the // function returned by NotifyCurrentMemcgPressureCallback @@ -101,30 +90,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) } }() return func() { - rw := fd.NewReadWriter(eventFD.FD()) - var buf [sizeofUint64]byte - hostarch.ByteOrder.PutUint64(buf[:], stopVal) - for { - n, err := rw.Write(buf[:]) - if err != nil { - if err == unix.EINTR { - continue - } - panic(fmt.Sprintf("failed to write to memory pressure level eventfd: %v", err)) - } - if n != sizeofUint64 { - panic(fmt.Sprintf("short write to memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64)) - } - break - } + eventFD.Write(stopVal) <-stopCh }, nil } - -func newEventFD() (*fd.FD, error) { - f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0) - if e != 0 { - return nil, fmt.Errorf("failed to create eventfd: %v", e) - } - return fd.New(int(f)), nil -} diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 5bba9de0b..2363cec5f 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -1,13 +1,26 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) +go_template_instance( + name = "atomicptr_netns", + out = "atomicptr_netns_unsafe.go", + package = "inet", + prefix = "Namespace", + template = "//pkg/sync/atomicptr:generic_atomicptr", + types = { + "Value": "Namespace", + }, +) + go_library( name = "inet", srcs = [ + "atomicptr_netns_unsafe.go", "context.go", "inet.go", "namespace.go", diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e4e0dc04f..c0f13bf52 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -268,6 +268,7 @@ go_library( "//pkg/sentry/mm", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", + "//pkg/sentry/seccheck", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/time", @@ -281,7 +282,6 @@ go_library( "//pkg/state/wire", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/stack", "//pkg/usermem", diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 7a1a36454..9aa03f506 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -66,6 +66,5 @@ go_library( "//pkg/errors/linuxerr", "//pkg/log", "//pkg/sync", - "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go index c93ef6ac1..a0e291f58 100644 --- a/pkg/sentry/kernel/cgroup.go +++ b/pkg/sentry/kernel/cgroup.go @@ -196,6 +196,7 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files // uniqueness of controllers enforced by Register, drop the // dying hierarchy now. The eventual unregister by the FS // teardown will become a no-op. + r.unregisterLocked(h.id) return nil } return h.fs diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index 564c3d42e..f240a68aa 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -9,13 +9,13 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fdnotifier", "//pkg/hostarch", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 4466fbc9d..5ea44a2c2 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -22,13 +22,13 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -145,7 +145,7 @@ func (e *EventOperations) hostRead(ctx context.Context, dst usermem.IOSequence) if _, err := unix.Read(e.hostfd, buf[:]); err != nil { if err == unix.EWOULDBLOCK { - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } return err } @@ -165,7 +165,7 @@ func (e *EventOperations) read(ctx context.Context, dst usermem.IOSequence) erro // We can't complete the read if the value is currently zero. if e.val == 0 { e.mu.Unlock() - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } // Update the value based on the mode the event is operating in. @@ -198,7 +198,7 @@ func (e *EventOperations) hostWrite(val uint64) error { hostarch.ByteOrder.PutUint64(buf[:], val) _, err := unix.Write(e.hostfd, buf[:]) if err == unix.EWOULDBLOCK { - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } return err } @@ -230,7 +230,7 @@ func (e *EventOperations) Signal(val uint64) error { // uint64 minus 1. if val > math.MaxUint64-1-e.val { e.mu.Unlock() - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } e.val += val diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index cfdea5cf7..c897e3a5f 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -42,7 +42,6 @@ go_library( "//pkg/log", "//pkg/sentry/memmap", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index f5c364c96..2c9ea65aa 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // KeyKind indicates the type of a Key. @@ -166,7 +165,7 @@ func atomicOp(t Target, addr hostarch.Addr, opIn uint32) (bool, error) { case linux.FUTEX_OP_XOR: newVal = oldVal ^ opArg default: - return false, syserror.ENOSYS + return false, linuxerr.ENOSYS } prev, err := t.CompareAndSwapUint32(addr, oldVal, newVal) if err != nil { @@ -192,7 +191,7 @@ func atomicOp(t Target, addr hostarch.Addr, opIn uint32) (bool, error) { case linux.FUTEX_OP_CMP_GE: return oldVal >= cmpArg, nil default: - return false, syserror.ENOSYS + return false, linuxerr.ENOSYS } } diff --git a/pkg/sentry/kernel/ipc/object.go b/pkg/sentry/kernel/ipc/object.go index 387b35e7e..facd157c7 100644 --- a/pkg/sentry/kernel/ipc/object.go +++ b/pkg/sentry/kernel/ipc/object.go @@ -19,6 +19,8 @@ package ipc import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) @@ -113,3 +115,36 @@ func (o *Object) CheckPermissions(creds *auth.Credentials, req fs.PermMask) bool } return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, o.UserNS) } + +// Set modifies attributes for an IPC object. See *ctl(IPC_SET). +// +// Precondition: Mechanism.mu must be held. +func (o *Object) Set(ctx context.Context, perm *linux.IPCPerm) error { + creds := auth.CredentialsFromContext(ctx) + uid := creds.UserNamespace.MapToKUID(auth.UID(perm.UID)) + gid := creds.UserNamespace.MapToKGID(auth.GID(perm.GID)) + if !uid.Ok() || !gid.Ok() { + // The man pages don't specify an errno for invalid uid/gid, but EINVAL + // is generally used for invalid arguments. + return linuxerr.EINVAL + } + + if !o.CheckOwnership(creds) { + // "The argument cmd has the value IPC_SET or IPC_RMID, but the + // effective user ID of the calling process is not the creator (as + // found in msg_perm.cuid) or the owner (as found in msg_perm.uid) + // of the message queue, and the caller is not privileged (Linux: + // does not have the CAP_SYS_ADMIN capability)." + return linuxerr.EPERM + } + + // User may only modify the lower 9 bits of the mode. All the other bits are + // always 0 for the underlying inode. + mode := linux.FileMode(perm.Mode & 0x1ff) + + o.Perms = fs.FilePermsFromMode(mode) + o.Owner.UID = uid + o.Owner.GID = gid + + return nil +} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index df5160b67..f913d25db 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -78,11 +78,19 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -// VFS2Enabled is set to true when VFS2 is enabled. Added as a global for allow -// easy access everywhere. To be removed once VFS2 becomes the default. +// VFS2Enabled is set to true when VFS2 is enabled. Added as a global to allow +// easy access everywhere. +// +// TODO(gvisor.dev/issue/1624): Remove when VFS1 is no longer used. var VFS2Enabled = false -// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow +// LISAFSEnabled is set to true when lisafs protocol is enabled. Added as a +// global to allow easy access everywhere. +// +// TODO(gvisor.dev/issue/6319): Remove when lisafs is default. +var LISAFSEnabled = false + +// FUSEEnabled is set to true when FUSE is enabled. Added as a global to allow // easy access everywhere. To be removed once FUSE is completed. var FUSEEnabled = false diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go index fab396d7c..c7c5e41fb 100644 --- a/pkg/sentry/kernel/msgqueue/msgqueue.go +++ b/pkg/sentry/kernel/msgqueue/msgqueue.go @@ -129,6 +129,16 @@ type Message struct { Size uint64 } +func (m *Message) makeCopy() *Message { + new := &Message{ + Type: m.Type, + Size: m.Size, + } + new.Text = make([]byte, len(m.Text)) + copy(new.Text, m.Text) + return new +} + // Blocker is used for blocking Queue.Send, and Queue.Receive calls that serves // as an abstracted version of kernel.Task. kernel.Task is not directly used to // prevent circular dependencies. @@ -206,6 +216,48 @@ func (r *Registry) FindByID(id ipc.ID) (*Queue, error) { return mech.(*Queue), nil } +// IPCInfo reports global parameters for message queues. See msgctl(IPC_INFO). +func (r *Registry) IPCInfo(ctx context.Context) *linux.MsgInfo { + return &linux.MsgInfo{ + MsgPool: linux.MSGPOOL, + MsgMap: linux.MSGMAP, + MsgMax: linux.MSGMAX, + MsgMnb: linux.MSGMNB, + MsgMni: linux.MSGMNI, + MsgSsz: linux.MSGSSZ, + MsgTql: linux.MSGTQL, + MsgSeg: linux.MSGSEG, + } +} + +// MsgInfo reports global parameters for message queues. See msgctl(MSG_INFO). +func (r *Registry) MsgInfo(ctx context.Context) *linux.MsgInfo { + r.mu.Lock() + defer r.mu.Unlock() + + var messages, bytes uint64 + r.reg.ForAllObjects( + func(o ipc.Mechanism) { + q := o.(*Queue) + q.mu.Lock() + messages += q.messageCount + bytes += q.byteCount + q.mu.Unlock() + }, + ) + + return &linux.MsgInfo{ + MsgPool: int32(r.reg.ObjectCount()), + MsgMap: int32(messages), + MsgTql: int32(bytes), + MsgMax: linux.MSGMAX, + MsgMnb: linux.MSGMNB, + MsgMni: linux.MSGMNI, + MsgSsz: linux.MSGSSZ, + MsgSeg: linux.MSGSEG, + } +} + // Send appends a message to the message queue, and returns an error if sending // fails. See msgsnd(2). func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) error { @@ -413,7 +465,7 @@ func (q *Queue) Copy(mType int64) (*Message, error) { if msg == nil { return nil, linuxerr.ENOMSG } - return msg, nil + return msg.makeCopy(), nil } // msgOfType returns the first message with the specified type, nil if no @@ -465,6 +517,73 @@ func (q *Queue) msgAtIndex(mType int64) *Message { return msg } +// Set modifies some values of the queue. See msgctl(IPC_SET). +func (q *Queue) Set(ctx context.Context, ds *linux.MsqidDS) error { + q.mu.Lock() + defer q.mu.Unlock() + + creds := auth.CredentialsFromContext(ctx) + if ds.MsgQbytes > maxQueueBytes && !creds.HasCapabilityIn(linux.CAP_SYS_RESOURCE, q.obj.UserNS) { + // "An attempt (IPC_SET) was made to increase msg_qbytes beyond the + // system parameter MSGMNB, but the caller is not privileged (Linux: + // does not have the CAP_SYS_RESOURCE capability)." + return linuxerr.EPERM + } + + if err := q.obj.Set(ctx, &ds.MsgPerm); err != nil { + return err + } + + q.maxBytes = ds.MsgQbytes + q.changeTime = ktime.NowFromContext(ctx) + return nil +} + +// Stat returns a MsqidDS object filled with information about the queue. See +// msgctl(IPC_STAT) and msgctl(MSG_STAT). +func (q *Queue) Stat(ctx context.Context) (*linux.MsqidDS, error) { + return q.stat(ctx, fs.PermMask{Read: true}) +} + +// StatAny is similar to Queue.Stat, but doesn't require read permission. See +// msgctl(MSG_STAT_ANY). +func (q *Queue) StatAny(ctx context.Context) (*linux.MsqidDS, error) { + return q.stat(ctx, fs.PermMask{}) +} + +// stat returns a MsqidDS object filled with information about the queue. An +// error is returned if the user doesn't have the specified permissions. +func (q *Queue) stat(ctx context.Context, mask fs.PermMask) (*linux.MsqidDS, error) { + q.mu.Lock() + defer q.mu.Unlock() + + creds := auth.CredentialsFromContext(ctx) + if !q.obj.CheckPermissions(creds, mask) { + // "The caller must have read permission on the message queue." + return nil, linuxerr.EACCES + } + + return &linux.MsqidDS{ + MsgPerm: linux.IPCPerm{ + Key: uint32(q.obj.Key), + UID: uint32(creds.UserNamespace.MapFromKUID(q.obj.Owner.UID)), + GID: uint32(creds.UserNamespace.MapFromKGID(q.obj.Owner.GID)), + CUID: uint32(creds.UserNamespace.MapFromKUID(q.obj.Creator.UID)), + CGID: uint32(creds.UserNamespace.MapFromKGID(q.obj.Creator.GID)), + Mode: uint16(q.obj.Perms.LinuxMode()), + Seq: 0, // IPC sequences not supported. + }, + MsgStime: q.sendTime.TimeT(), + MsgRtime: q.receiveTime.TimeT(), + MsgCtime: q.changeTime.TimeT(), + MsgCbytes: q.byteCount, + MsgQnum: q.messageCount, + MsgQbytes: q.maxBytes, + MsgLspid: q.sendPID, + MsgLrpid: q.receivePID, + }, nil +} + // Lock implements ipc.Mechanism.Lock. func (q *Queue) Lock() { q.mu.Lock() diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 94ebac7c5..5b2bac783 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", @@ -51,7 +50,6 @@ go_test( "//pkg/errors/linuxerr", "//pkg/sentry/contexttest", "//pkg/sentry/fs", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 08786d704..615591507 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // inodeOperations implements fs.InodeOperations for pipes. @@ -95,7 +94,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() { if !waitFor(&i.mu, &i.wWakeup, ctx) { r.DecRef(ctx) - return nil, syserror.ErrInterrupted + return nil, linuxerr.ErrInterrupted } } @@ -118,7 +117,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi if !waitFor(&i.mu, &i.rWakeup, ctx) { w.DecRef(ctx) - return nil, syserror.ErrInterrupted + return nil, linuxerr.ErrInterrupted } } return w, nil diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index d25cf658e..31bd7910a 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/syserror" ) type sleeper struct { @@ -240,7 +239,7 @@ func TestBlockedOpenIsCancellable(t *testing.T) { // If the cancel on the sleeper didn't work, the open for read would never // return. res := <-done - if res.error != syserror.ErrInterrupted { + if res.error != linuxerr.ErrInterrupted { t.Fatalf("Cancellation didn't cause GetFile to return fs.ErrInterrupted, got %v.", res.error) } diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 85e3ce9f4..86beee6fe 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -201,7 +200,7 @@ func (p *Pipe) peekLocked(count int64, f func(safemem.BlockSeq) (uint64, error)) if !p.HasWriters() { return 0, io.EOF } - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } count = p.size } @@ -250,7 +249,7 @@ func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error) avail := p.max - p.size if avail == 0 { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } short := false if count > avail { @@ -258,7 +257,7 @@ func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error) // (PIPE_BUF) be atomic, but requires no atomicity for writes // larger than this. if count <= atomicIOBytes { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } count = avail short = true @@ -307,7 +306,7 @@ func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error) // If we shortened the write, adjust the returned error appropriately. if short { - return done, syserror.ErrWouldBlock + return done, linuxerr.ErrWouldBlock } return done, nil diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go index 867f4a76b..aa3ab305d 100644 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ b/pkg/sentry/kernel/pipe/pipe_test.go @@ -18,8 +18,8 @@ import ( "bytes" "testing" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -51,8 +51,8 @@ func TestPipeReadBlock(t *testing.T) { defer w.DecRef(ctx) n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1))) - if n != 0 || err != syserror.ErrWouldBlock { - t.Fatalf("Readv: got (%d, %v), wanted (0, %v)", n, err, syserror.ErrWouldBlock) + if n != 0 || err != linuxerr.ErrWouldBlock { + t.Fatalf("Readv: got (%d, %v), wanted (0, %v)", n, err, linuxerr.ErrWouldBlock) } } @@ -67,7 +67,7 @@ func TestPipeWriteBlock(t *testing.T) { msg := make([]byte, capacity+1) n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) - if wantN, wantErr := int64(capacity), syserror.ErrWouldBlock; n != wantN || err != wantErr { + if wantN, wantErr := int64(capacity), linuxerr.ErrWouldBlock; n != wantN || err != wantErr { t.Fatalf("Writev: got (%d, %v), wanted (%d, %v)", n, err, wantN, wantErr) } } @@ -102,7 +102,7 @@ func TestPipeWriteUntilEnd(t *testing.T) { for { n, err := r.Readv(ctx, dst) dst = dst.DropFirst64(n) - if err == syserror.ErrWouldBlock { + if err == linuxerr.ErrWouldBlock { select { case <-ch: continue @@ -129,7 +129,7 @@ func TestPipeWriteUntilEnd(t *testing.T) { for src.NumBytes() != 0 { n, err := w.Writev(ctx, src) src = src.DropFirst64(n) - if err == syserror.ErrWouldBlock { + if err == linuxerr.ErrWouldBlock { <-ch continue } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 077d5fd7f..a6f1989f5 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -121,7 +120,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s // writer, we have to wait for a writer to open the other end. if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) { fd.DecRef(ctx) - return nil, syserror.EINTR + return nil, linuxerr.EINTR } case writable: @@ -137,7 +136,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s // Wait for a reader to open the other end. if !waitFor(&vp.mu, &vp.rWakeup, ctx) { fd.DecRef(ctx) - return nil, syserror.EINTR + return nil, linuxerr.EINTR } } diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 079294f81..717c9a6b3 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -465,7 +464,7 @@ func (t *Task) ptraceUnfreezeLocked() { // stop. func (t *Task) ptraceUnstop(mode ptraceSyscallMode, singlestep bool, sig linux.Signal) error { if sig != 0 && !sig.IsValid() { - return syserror.EIO + return linuxerr.EIO } t.tg.pidns.owner.mu.Lock() defer t.tg.pidns.owner.mu.Unlock() @@ -532,7 +531,7 @@ func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error { } if seize { if err := target.ptraceSetOptionsLocked(opts); err != nil { - return syserror.EIO + return linuxerr.EIO } } target.ptraceTracer.Store(t) @@ -569,7 +568,7 @@ func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error { // ptrace stop. func (t *Task) ptraceDetach(target *Task, sig linux.Signal) error { if sig != 0 && !sig.IsValid() { - return syserror.EIO + return linuxerr.EIO } t.tg.pidns.owner.mu.Lock() defer t.tg.pidns.owner.mu.Unlock() @@ -967,7 +966,7 @@ func (t *Task) ptraceInterrupt(target *Task) error { return linuxerr.ESRCH } if !target.ptraceSeized { - return syserror.EIO + return linuxerr.EIO } target.tg.signalHandlers.mu.Lock() defer target.tg.signalHandlers.mu.Unlock() @@ -1030,7 +1029,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error { if req == linux.PTRACE_ATTACH || req == linux.PTRACE_SEIZE { seize := req == linux.PTRACE_SEIZE if seize && addr != 0 { - return syserror.EIO + return linuxerr.EIO } return t.ptraceAttach(target, seize, uintptr(data)) } @@ -1120,13 +1119,13 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() if !target.ptraceSeized { - return syserror.EIO + return linuxerr.EIO } if target.ptraceSiginfo == nil { - return syserror.EIO + return linuxerr.EIO } if target.ptraceSiginfo.Code>>8 != linux.PTRACE_EVENT_STOP { - return syserror.EIO + return linuxerr.EIO } target.tg.signalHandlers.mu.Lock() defer target.tg.signalHandlers.mu.Unlock() diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go index 63422e155..564add01b 100644 --- a/pkg/sentry/kernel/ptrace_amd64.go +++ b/pkg/sentry/kernel/ptrace_amd64.go @@ -19,8 +19,8 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -88,6 +88,6 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data hostarch.Addr) err return err default: - return syserror.EIO + return linuxerr.EIO } } diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go index 27514d67b..7c2b94339 100644 --- a/pkg/sentry/kernel/ptrace_arm64.go +++ b/pkg/sentry/kernel/ptrace_arm64.go @@ -18,11 +18,11 @@ package kernel import ( + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/syserror" ) // ptraceArch implements arch-specific ptrace commands. func (t *Task) ptraceArch(target *Task, req int64, addr, data hostarch.Addr) error { - return syserror.EIO + return linuxerr.EIO } diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index 54ca43c2e..0d66648c3 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -18,9 +18,9 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/syserror" ) const maxSyscallFilterInstructions = 1 << 15 @@ -176,7 +176,7 @@ func (t *Task) AppendSyscallFilter(p bpf.Program, syncAll bool) error { } if totalLength > maxSyscallFilterInstructions { - return syserror.ENOMEM + return linuxerr.ENOMEM } newFilters = append(newFilters, p) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 2ae08ed12..6aa74219e 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/sentry/kernel/ipc", "//pkg/sentry/kernel/time", "//pkg/sync", - "//pkg/syserror", ], ) @@ -43,9 +42,9 @@ go_test( deps = [ "//pkg/abi/linux", # keep "//pkg/context", # keep + "//pkg/errors/linuxerr", #keep "//pkg/sentry/contexttest", # keep "//pkg/sentry/kernel/auth", # keep "//pkg/sentry/kernel/ipc", # keep - "//pkg/syserror", # keep ], ) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 8610d3fc1..28e466948 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/ipc" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -151,10 +150,10 @@ func (r *Registry) FindOrCreate(ctx context.Context, key ipc.Key, nsems int32, m // Map reg.objects and map indexes in a registry are of the same size, // check map reg.objects only here for the system limit. if r.reg.ObjectCount() >= setsMax { - return nil, syserror.ENOSPC + return nil, linuxerr.ENOSPC } if r.totalSems() > int(semsTotalMax-nsems) { - return nil, syserror.ENOSPC + return nil, linuxerr.ENOSPC } // Finally create a new set. @@ -337,19 +336,15 @@ func (s *Set) Size() int { return len(s.sems) } -// Change changes some fields from the set atomically. -func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.FileOwner, perms fs.FilePermissions) error { +// Set modifies attributes for a semaphore set. See semctl(IPC_SET). +func (s *Set) Set(ctx context.Context, ds *linux.SemidDS) error { s.mu.Lock() defer s.mu.Unlock() - // "The effective UID of the calling process must match the owner or creator - // of the semaphore set, or the caller must be privileged." - if !s.obj.CheckOwnership(creds) { - return linuxerr.EACCES + if err := s.obj.Set(ctx, &ds.SemPerm); err != nil { + return err } - s.obj.Owner = owner - s.obj.Perms = perms s.changeTime = ktime.NowFromContext(ctx) return nil } @@ -549,7 +544,7 @@ func (s *Set) ExecuteOps(ctx context.Context, ops []linux.Sembuf, creds *auth.Cr // Did it race with a removal operation? if s.dead { - return nil, 0, syserror.EIDRM + return nil, 0, linuxerr.EIDRM } // Validate the operations. @@ -588,7 +583,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch if tmpVals[op.SemNum] != 0 { // Semaphore isn't 0, must wait. if op.SemFlg&linux.IPC_NOWAIT != 0 { - return nil, 0, syserror.ErrWouldBlock + return nil, 0, linuxerr.ErrWouldBlock } w := newWaiter(op.SemOp) @@ -604,7 +599,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch if -op.SemOp > tmpVals[op.SemNum] { // Not enough resources, must wait. if op.SemFlg&linux.IPC_NOWAIT != 0 { - return nil, 0, syserror.ErrWouldBlock + return nil, 0, linuxerr.ErrWouldBlock } w := newWaiter(op.SemOp) diff --git a/pkg/sentry/kernel/semaphore/semaphore_test.go b/pkg/sentry/kernel/semaphore/semaphore_test.go index 2e4ab8121..59ac92ef1 100644 --- a/pkg/sentry/kernel/semaphore/semaphore_test.go +++ b/pkg/sentry/kernel/semaphore/semaphore_test.go @@ -19,10 +19,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/ipc" - "gvisor.dev/gvisor/pkg/syserror" ) func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} { @@ -124,14 +124,14 @@ func TestNoWait(t *testing.T) { ops[0].SemOp = -2 ops[0].SemFlg = linux.IPC_NOWAIT - if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock { - t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock) + if _, _, err := set.executeOps(ctx, ops, 123); err != linuxerr.ErrWouldBlock { + t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, linuxerr.ErrWouldBlock) } ops[0].SemOp = 0 ops[0].SemFlg = linux.IPC_NOWAIT - if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock { - t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock) + if _, _, err := set.executeOps(ctx, ops, 123); err != linuxerr.ErrWouldBlock { + t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, linuxerr.ErrWouldBlock) } } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 4e8deac4c..2547957ba 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -42,7 +42,6 @@ go_library( "//pkg/sentry/pgalloc", "//pkg/sentry/usage", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 2abf467d7..ab938fa3c 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -49,7 +49,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // Registry tracks all shared memory segments in an IPC namespace. The registry @@ -151,7 +150,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key ipc.Key, siz if r.reg.ObjectCount() >= linux.SHMMNI { // "All possible shared memory IDs have been taken (SHMMNI) ..." // - man shmget(2) - return nil, syserror.ENOSPC + return nil, linuxerr.ENOSPC } if !private { @@ -184,7 +183,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key ipc.Key, siz // "... allocating a segment of the requested size would cause the // system to exceed the system-wide limit on shared memory (SHMALL)." // - man shmget(2) - return nil, syserror.ENOSPC + return nil, linuxerr.ENOSPC } // Need to create a new segment. @@ -521,7 +520,7 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr hostarch.Addr, opts Atta s.mu.Lock() defer s.mu.Unlock() if s.pendingDestruction && s.ReadRefs() == 0 { - return memmap.MMapOpts{}, syserror.EIDRM + return memmap.MMapOpts{}, linuxerr.EIDRM } creds := auth.CredentialsFromContext(ctx) @@ -619,25 +618,10 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error { s.mu.Lock() defer s.mu.Unlock() - creds := auth.CredentialsFromContext(ctx) - if !s.obj.CheckOwnership(creds) { - return linuxerr.EPERM - } - - uid := creds.UserNamespace.MapToKUID(auth.UID(ds.ShmPerm.UID)) - gid := creds.UserNamespace.MapToKGID(auth.GID(ds.ShmPerm.GID)) - if !uid.Ok() || !gid.Ok() { - return linuxerr.EINVAL + if err := s.obj.Set(ctx, &ds.ShmPerm); err != nil { + return err } - // User may only modify the lower 9 bits of the mode. All the other bits are - // always 0 for the underlying inode. - mode := linux.FileMode(ds.ShmPerm.Mode & 0x1ff) - s.obj.Perms = fs.FilePermsFromMode(mode) - - s.obj.Owner.UID = uid - s.obj.Owner.GID = gid - s.changeTime = ktime.NowFromContext(ctx) return nil } diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 1110ecca5..4180ca28e 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -15,7 +15,6 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 47958e2d4..9c5e6698c 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -99,7 +98,7 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS info, err := s.target.Sigtimedwait(s.Mask(), 0) if err != nil { // There must be no signal available. - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } // Copy out the signal info using the specified format. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 59eeb253d..b0004482c 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/sched" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/seccheck" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -510,9 +511,7 @@ type Task struct { numaNodeMask uint64 // netns is the task's network namespace. netns is never nil. - // - // netns is protected by mu. - netns *inet.Namespace + netns inet.NamespaceAtomicPtr // If rseqPreempted is true, before the next call to p.Switch(), // interrupt rseq critical regions as defined by rseqAddr and @@ -874,3 +873,23 @@ func (t *Task) ResetKcov() { t.kcov = nil } } + +// Preconditions: The TaskSet mutex must be locked. +func (t *Task) loadSeccheckInfoLocked(req seccheck.TaskFieldSet, mask *seccheck.TaskFieldSet, info *seccheck.TaskInfo) { + if req.Contains(seccheck.TaskFieldThreadID) { + info.ThreadID = int32(t.k.tasks.Root.tids[t]) + mask.Add(seccheck.TaskFieldThreadID) + } + if req.Contains(seccheck.TaskFieldThreadStartTime) { + info.ThreadStartTime = t.startTime + mask.Add(seccheck.TaskFieldThreadStartTime) + } + if req.Contains(seccheck.TaskFieldThreadGroupID) { + info.ThreadGroupID = int32(t.k.tasks.Root.tgids[t.tg]) + mask.Add(seccheck.TaskFieldThreadGroupID) + } + if req.Contains(seccheck.TaskFieldThreadGroupStartTime) { + info.ThreadGroupStartTime = t.tg.leader.startTime + mask.Add(seccheck.TaskFieldThreadGroupStartTime) + } +} diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index b2520eecf..9bfc155e4 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // BlockWithTimeout blocks t until an event is received from C, the application @@ -33,7 +32,7 @@ import ( // and is unspecified if haveTimeout is false. // // - An error which is nil if an event is received from C, ETIMEDOUT if the timeout -// expired, and syserror.ErrInterrupted if t is interrupted. +// expired, and linuxerr.ErrInterrupted if t is interrupted. // // Preconditions: The caller must be running on the task goroutine. func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.Duration) (time.Duration, error) { @@ -67,7 +66,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time. // application monotonic clock indicates a time of deadline (only if // haveDeadline is true), or t is interrupted. It returns nil if an event is // received from C, ETIMEDOUT if the deadline expired, and -// syserror.ErrInterrupted if t is interrupted. +// linuxerr.ErrInterrupted if t is interrupted. // // Preconditions: The caller must be running on the task goroutine. func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline ktime.Time) error { @@ -95,7 +94,7 @@ func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline // BlockWithTimer blocks t until an event is received from C or tchan, or t is // interrupted. It returns nil if an event is received from C, ETIMEDOUT if an -// event is received from tchan, and syserror.ErrInterrupted if t is +// event is received from tchan, and linuxerr.ErrInterrupted if t is // interrupted. // // Most clients should use BlockWithDeadline or BlockWithTimeout instead. @@ -106,7 +105,7 @@ func (t *Task) BlockWithTimer(C <-chan struct{}, tchan <-chan struct{}) error { } // Block blocks t until an event is received from C or t is interrupted. It -// returns nil if an event is received from C and syserror.ErrInterrupted if t +// returns nil if an event is received from C and linuxerr.ErrInterrupted if t // is interrupted. // // Preconditions: The caller must be running on the task goroutine. @@ -157,7 +156,7 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { region.End() t.SleepFinish(false) // Return the indicated error on interrupt. - return syserror.ErrInterrupted + return linuxerr.ErrInterrupted case <-timerChan: region.End() diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index da4b77ca2..a6d8fb163 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/seccheck" "gvisor.dev/gvisor/pkg/usermem" ) @@ -235,7 +236,23 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) { // nt that it must receive before its task goroutine starts running. tid := nt.k.tasks.Root.IDOfTask(nt) defer nt.Start(tid) - t.traceCloneEvent(tid) + + if seccheck.Global.Enabled(seccheck.PointClone) { + mask, info := getCloneSeccheckInfo(t, nt, args) + if err := seccheck.Global.Clone(t, mask, &info); err != nil { + // nt has been visible to the rest of the system since NewTask, so + // it may be blocking execve or a group stop, have been notified + // for group signal delivery, had children reparented to it, etc. + // Thus we can't just drop it on the floor. Instead, instruct the + // task goroutine to exit immediately, as quietly as possible. + nt.exitTracerNotified = true + nt.exitTracerAcked = true + nt.exitParentNotified = true + nt.exitParentAcked = true + nt.runState = (*runExitMain)(nil) + return 0, nil, err + } + } // "If fork/clone and execve are allowed by @prog, any child processes will // be constrained to the same filters and system call ABI as the parent." - @@ -260,6 +277,7 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) { ntid.CopyOut(t, hostarch.Addr(args.ParentTID)) } + t.traceCloneEvent(tid) kind := ptraceCloneKindClone if args.Flags&linux.CLONE_VFORK != 0 { kind = ptraceCloneKindVfork @@ -279,6 +297,22 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) { return ntid, nil, nil } +func getCloneSeccheckInfo(t, nt *Task, args *linux.CloneArgs) (seccheck.CloneFieldSet, seccheck.CloneInfo) { + req := seccheck.Global.CloneReq() + info := seccheck.CloneInfo{ + Credentials: t.Credentials(), + Args: *args, + } + var mask seccheck.CloneFieldSet + mask.Add(seccheck.CloneFieldCredentials) + mask.Add(seccheck.CloneFieldArgs) + t.k.tasks.mu.RLock() + defer t.k.tasks.mu.RUnlock() + t.loadSeccheckInfoLocked(req.Invoker, &mask.Invoker, &info.Invoker) + nt.loadSeccheckInfoLocked(req.Created, &mask.Created, &info.Created) + return mask, info +} + // maybeBeginVforkStop checks if a previously-started vfork child is still // running and has not yet released its MM, such that its parent t should enter // a vforkStop. @@ -410,7 +444,7 @@ func (t *Task) Unshare(flags int32) error { t.mu.Unlock() return linuxerr.EPERM } - t.netns = inet.NewNamespace(t.netns) + t.netns.Store(inet.NewNamespace(t.netns.Load())) } if flags&linux.CLONE_NEWUTS != 0 { if !haveCapSysAdmin { diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index cf8571262..db91fc4d8 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -66,10 +66,10 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // execStop is a TaskStop that a task sets on itself when it wants to execve @@ -97,7 +97,7 @@ func (t *Task) Execve(newImage *TaskImage) (*SyscallControl, error) { // We lost to a racing group-exit, kill, or exec from another thread // and should just exit. newImage.release() - return nil, syserror.EINTR + return nil, linuxerr.EINTR } // Cancel any racing group stops. @@ -222,9 +222,15 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { // Update credentials to reflect the execve. This should precede switching // MMs to ensure that dumpability has been reset first, if needed. t.updateCredsForExecLocked() - t.image.release() + oldImage := t.image t.image = *r.image t.mu.Unlock() + + // Don't hold t.mu while calling t.image.release(), that may + // attempt to acquire TaskImage.MemoryManager.mappingMu, a lock order + // violation. + oldImage.release() + t.unstopVforkParent() t.p.FullStateChanged() // NOTE(b/30316266): All locks must be dropped prior to calling Activate. diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index fbfcc19e5..b3931445b 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -32,7 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/waiter" ) @@ -230,9 +230,16 @@ func (*runExitMain) execute(t *Task) taskRunState { t.tg.pidns.owner.mu.Lock() t.updateRSSLocked() t.tg.pidns.owner.mu.Unlock() + + // Release the task image resources. Accessing these fields must be + // done with t.mu held, but the mm.DecUsers() call must be done outside + // of that lock. t.mu.Lock() - t.image.release() + mm := t.image.MemoryManager + t.image.MemoryManager = nil + t.image.fu = nil t.mu.Unlock() + mm.DecUsers(t) // Releasing the MM unblocks a blocked CLONE_VFORK parent. t.unstopVforkParent() @@ -859,7 +866,7 @@ func (t *Task) Wait(opts *WaitOptions) (*WaitResult, error) { return wr, err } if err := t.Block(ch); err != nil { - return wr, syserror.ConvertIntr(err, opts.BlockInterruptErr) + return wr, syserr.ConvertIntr(err, opts.BlockInterruptErr) } } } diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go index c132c27ef..6002ffb42 100644 --- a/pkg/sentry/kernel/task_image.go +++ b/pkg/sentry/kernel/task_image.go @@ -53,7 +53,7 @@ type TaskImage struct { } // release releases all resources held by the TaskImage. release is called by -// the task when it execs into a new TaskImage or exits. +// the task when it execs into a new TaskImage. func (image *TaskImage) release() { // Nil out pointers so that if the task is saved after release, it doesn't // follow the pointers to possibly now-invalid objects. diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index 8de08151a..f0c168ecc 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -191,9 +191,11 @@ const ( // // Preconditions: The task's owning TaskSet.mu must be locked. func (t *Task) updateInfoLocked() { - // Use the task's TID in the root PID namespace for logging. + // Use the task's TID and PID in the root PID namespace for logging. + pid := t.tg.pidns.owner.Root.tgids[t.tg] tid := t.tg.pidns.owner.Root.tids[t] - t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid)) + t.logPrefix.Store(fmt.Sprintf("[% 4d:% 4d] ", pid, tid)) + t.rebuildTraceContext(tid) } @@ -249,5 +251,9 @@ func (t *Task) traceExecEvent(image *TaskImage) { return } defer file.DecRef(t) - trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t)) + + // traceExecEvent function may be called before the task goroutine + // starts, so we must use the async context. + name := file.PathnameWithDeleted(t.AsyncContext()) + trace.Logf(t.traceContext, traceCategory, "exec: %s", name) } diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go index f7711232c..e31e2b2e8 100644 --- a/pkg/sentry/kernel/task_net.go +++ b/pkg/sentry/kernel/task_net.go @@ -20,9 +20,7 @@ import ( // IsNetworkNamespaced returns true if t is in a non-root network namespace. func (t *Task) IsNetworkNamespaced() bool { - t.mu.Lock() - defer t.mu.Unlock() - return !t.netns.IsRoot() + return !t.netns.Load().IsRoot() } // NetworkContext returns the network stack used by the task. NetworkContext @@ -31,14 +29,10 @@ func (t *Task) IsNetworkNamespaced() bool { // TODO(gvisor.dev/issue/1833): Migrate callers of this method to // NetworkNamespace(). func (t *Task) NetworkContext() inet.Stack { - t.mu.Lock() - defer t.mu.Unlock() - return t.netns.Stack() + return t.netns.Load().Stack() } // NetworkNamespace returns the network namespace observed by the task. func (t *Task) NetworkNamespace() *inet.Namespace { - t.mu.Lock() - defer t.mu.Unlock() - return t.netns + return t.netns.Load() } diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 054ff212f..7b336a46b 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -22,6 +22,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/goid" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -29,7 +30,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/syserror" ) // A taskRunState is a reified state in the task state machine. See README.md @@ -197,8 +197,8 @@ func (app *runApp) execute(t *Task) taskRunState { // a pending signal, causing another interruption, but that signal should // not interact with the interrupted syscall.) if t.haveSyscallReturn { - if sre, ok := syserror.SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { - if sre == syserror.ERESTART_RESTARTBLOCK { + if sre, ok := linuxerr.SyscallRestartErrorFromReturn(t.Arch().Return()); ok { + if sre == linuxerr.ERESTART_RESTARTBLOCK { t.Debugf("Restarting syscall %d with restart block after errno %d: not interrupted by handled signal", t.Arch().SyscallNo(), sre) t.Arch().RestartSyscallWithRestartBlock() } else { diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 7065ac79c..eeb3c5e69 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -161,7 +160,7 @@ func (t *Task) deliverSignal(info *linux.SignalInfo, act linux.SigAction) taskRu sigact := computeAction(sig, act) if t.haveSyscallReturn { - if sre, ok := syserror.SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { + if sre, ok := linuxerr.SyscallRestartErrorFromReturn(t.Arch().Return()); ok { // Signals that are ignored, cause a thread group stop, or // terminate the thread group do not interact with interrupted // syscalls; in Linux terms, they are never returned to the signal @@ -170,13 +169,13 @@ func (t *Task) deliverSignal(info *linux.SignalInfo, act linux.SigAction) taskRu // signal that is actually handled (by userspace). if sigact == SignalActionHandler { switch { - case sre == syserror.ERESTARTNOHAND: + case sre == linuxerr.ERESTARTNOHAND: fallthrough - case sre == syserror.ERESTART_RESTARTBLOCK: + case sre == linuxerr.ERESTART_RESTARTBLOCK: fallthrough - case (sre == syserror.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0): + case (sre == linuxerr.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0): t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) - t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1))) + t.Arch().SetReturn(uintptr(-ExtractErrno(linuxerr.EINTR, -1))) default: t.Debugf("Restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) t.Arch().RestartSyscall() diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 0565059c1..4919dea7c 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/sched" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // TaskConfig defines the configuration of a new Task (see below). @@ -141,7 +140,6 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { allowedCPUMask: cfg.AllowedCPUMask.Copy(), ioUsage: &usage.IO{}, niceness: cfg.Niceness, - netns: cfg.NetworkNamespace, utsns: cfg.UTSNamespace, ipcns: cfg.IPCNamespace, abstractSockets: cfg.AbstractSocketNamespace, @@ -153,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { containerID: cfg.ContainerID, cgroups: make(map[Cgroup]struct{}), } + t.netns.Store(cfg.NetworkNamespace) t.creds.Store(cfg.Credentials) t.endStopCond.L = &t.tg.signalHandlers.mu t.ptraceTracer.Store((*Task)(nil)) @@ -170,7 +169,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { // doesn't matter too much since the caller will exit before it returns // to userspace. If the caller isn't in the same thread group, then // we're in uncharted territory and can return whatever we want. - return nil, syserror.EINTR + return nil, linuxerr.EINTR } if err := ts.assignTIDsLocked(t); err != nil { return nil, err @@ -268,7 +267,7 @@ func (ns *PIDNamespace) allocateTID() (ThreadID, error) { // fail with the error ENOMEM; it is not possible to create a new // processes [sic] in a PID namespace whose init process has // terminated." - pid_namespaces(7) - return 0, syserror.ENOMEM + return 0, linuxerr.ENOMEM } tid := ns.last for { diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 0586c9def..2b1d7e114 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/syserror" ) // SyscallRestartBlock represents the restart block for a syscall restartable @@ -383,8 +382,6 @@ func ExtractErrno(err error, sysno int) int { return int(err) case *errors.Error: return int(err.Errno()) - case syserror.SyscallRestartErrno: - return int(err) case *memmap.BusError: // Bus errors may generate SIGBUS, but for syscalls they still // return EFAULT. See case in task_run.go where the fault is @@ -397,8 +394,8 @@ func ExtractErrno(err error, sysno int) int { case *os.SyscallError: return ExtractErrno(err.Err, sysno) default: - if errno, ok := syserror.TranslateError(err); ok { - return int(errno) + if errno, ok := linuxerr.TranslateError(err); ok { + return int(errno.Errno()) } } panic(fmt.Sprintf("Unknown syscall %d error: %v", sysno, err)) diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index 8e2c36598..bff226a11 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -105,7 +104,7 @@ func (t *Task) CopyInVector(addr hostarch.Addr, maxElemSize, maxTotalSize int) ( // Each string has a zero terminating byte counted, so copying out a string // requires at least one byte of space. Also, see the calculation below. if maxTotalSize <= 0 { - return nil, syserror.ENOMEM + return nil, linuxerr.ENOMEM } thisMax := maxElemSize if maxTotalSize < thisMax { @@ -148,7 +147,7 @@ func (t *Task) CopyOutIovecs(addr hostarch.Addr, src hostarch.AddrRangeSeq) erro } default: - return syserror.ENOSYS + return linuxerr.ENOSYS } return nil @@ -220,7 +219,7 @@ func (t *Task) CopyInIovecs(addr hostarch.Addr, numIovecs int) (hostarch.AddrRan } default: - return hostarch.AddrRangeSeq{}, syserror.ENOSYS + return hostarch.AddrRangeSeq{}, linuxerr.ENOSYS } // Truncate to MAX_RW_COUNT. diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 2eda15303..5814a4eca 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -489,11 +489,6 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) tg.signalHandlers.mu.Lock() defer tg.signalHandlers.mu.Unlock() - // TODO(gvisor.dev/issue/6148): "If tcsetpgrp() is called by a member of a - // background process group in its session, and the calling process is not - // blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all members of - // this background process group." - // tty must be the controlling terminal. if tg.tty != tty { return -1, linuxerr.ENOTTY @@ -516,6 +511,16 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) return -1, linuxerr.EPERM } + signalAction := tg.signalHandlers.actions[linux.SIGTTOU] + // If the calling process is a member of a background group, a SIGTTOU + // signal is sent to all members of this background process group. + // We need also need to check whether it is ignoring or blocking SIGTTOU. + ignored := signalAction.Handler == linux.SIG_IGN + blocked := tg.leader.signalMask == linux.SignalSetOf(linux.SIGTTOU) + if tg.processGroup.id != tg.processGroup.session.foreground.id && !ignored && !blocked { + tg.leader.sendSignalLocked(SignalInfoPriv(linux.SIGTTOU), true) + } + tg.processGroup.session.foreground.id = pgid return 0, nil } diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 54bfed644..560a0f33c 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -37,7 +37,6 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/vfs", "//pkg/syserr", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 577374fa4..fb213d109 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -32,7 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -116,7 +115,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { log.Infof("Error reading ELF ident: %v", err) // The entire ident array always exists. if err == io.EOF || err == io.ErrUnexpectedEOF { - err = syserror.ENOEXEC + err = linuxerr.ENOEXEC } return elfInfo{}, err } @@ -124,22 +123,22 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // Only some callers pre-check the ELF magic. if !bytes.Equal(ident[:len(elfMagic)], []byte(elfMagic)) { log.Infof("File is not an ELF") - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } // We only support 64-bit, little endian binaries if class := elf.Class(ident[elf.EI_CLASS]); class != elf.ELFCLASS64 { log.Infof("Unsupported ELF class: %v", class) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } if endian := elf.Data(ident[elf.EI_DATA]); endian != elf.ELFDATA2LSB { log.Infof("Unsupported ELF endianness: %v", endian) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT { log.Infof("Unsupported ELF version: %v", version) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } // EI_OSABI is ignored by Linux, which is the only OS supported. os := abi.Linux @@ -151,7 +150,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { log.Infof("Error reading ELF header: %v", err) // The entire header always exists. if err == io.EOF || err == io.ErrUnexpectedEOF { - err = syserror.ENOEXEC + err = linuxerr.ENOEXEC } return elfInfo{}, err } @@ -166,7 +165,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { a = arch.ARM64 default: log.Infof("Unsupported ELF machine %d", machine) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } var sharedObject bool @@ -178,25 +177,25 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { sharedObject = true default: log.Infof("Unsupported ELF type %v", elfType) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } if int(hdr.Phentsize) != prog64Size { log.Infof("Unsupported phdr size %d", hdr.Phentsize) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } totalPhdrSize := prog64Size * int(hdr.Phnum) if totalPhdrSize < prog64Size { log.Warningf("No phdrs or total phdr size overflows: prog64Size: %d phnum: %d", prog64Size, int(hdr.Phnum)) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } if totalPhdrSize > maxTotalPhdrSize { log.Infof("Too many phdrs (%d): total size %d > %d", hdr.Phnum, totalPhdrSize, maxTotalPhdrSize) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } if int64(hdr.Phoff) < 0 || int64(hdr.Phoff+uint64(totalPhdrSize)) < 0 { ctx.Infof("Unsupported phdr offset %d", hdr.Phoff) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } phdrBuf := make([]byte, totalPhdrSize) @@ -205,7 +204,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { log.Infof("Error reading ELF phdrs: %v", err) // If phdrs were specified, they should all exist. if err == io.EOF || err == io.ErrUnexpectedEOF { - err = syserror.ENOEXEC + err = linuxerr.ENOEXEC } return elfInfo{}, err } @@ -248,19 +247,19 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr if !ok { // If offset != 0 we should have ensured this would fit. ctx.Warningf("Computed segment load address overflows: %#x + %#x", phdr.Vaddr, offset) - return syserror.ENOEXEC + return linuxerr.ENOEXEC } addr -= hostarch.Addr(adjust) fileSize := phdr.Filesz + adjust if fileSize < phdr.Filesz { ctx.Infof("Computed segment file size overflows: %#x + %#x", phdr.Filesz, adjust) - return syserror.ENOEXEC + return linuxerr.ENOEXEC } ms, ok := hostarch.Addr(fileSize).RoundUp() if !ok { ctx.Infof("fileSize %#x too large", fileSize) - return syserror.ENOEXEC + return linuxerr.ENOEXEC } mapSize := uint64(ms) @@ -321,7 +320,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr memSize := phdr.Memsz + adjust if memSize < phdr.Memsz { ctx.Infof("Computed segment mem size overflows: %#x + %#x", phdr.Memsz, adjust) - return syserror.ENOEXEC + return linuxerr.ENOEXEC } // Allocate more anonymous pages if necessary. @@ -333,7 +332,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr anonSize, ok := hostarch.Addr(memSize - mapSize).RoundUp() if !ok { ctx.Infof("extra anon pages too large: %#x", memSize-mapSize) - return syserror.ENOEXEC + return linuxerr.ENOEXEC } // N.B. Linux uses vm_brk_flags to map these pages, which only @@ -423,27 +422,27 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in // NOTE(b/37474556): Linux allows out-of-order // segments, in violation of the spec. ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end) - return loadedELF{}, syserror.ENOEXEC + return loadedELF{}, linuxerr.ENOEXEC } var ok bool end, ok = vaddr.AddLength(phdr.Memsz) if !ok { ctx.Infof("PT_LOAD header size overflows. %#x + %#x", vaddr, phdr.Memsz) - return loadedELF{}, syserror.ENOEXEC + return loadedELF{}, linuxerr.ENOEXEC } case elf.PT_INTERP: if phdr.Filesz < 2 { ctx.Infof("PT_INTERP path too small: %v", phdr.Filesz) - return loadedELF{}, syserror.ENOEXEC + return loadedELF{}, linuxerr.ENOEXEC } if phdr.Filesz > linux.PATH_MAX { ctx.Infof("PT_INTERP path too big: %v", phdr.Filesz) - return loadedELF{}, syserror.ENOEXEC + return loadedELF{}, linuxerr.ENOEXEC } if int64(phdr.Off) < 0 || int64(phdr.Off+phdr.Filesz) < 0 { ctx.Infof("Unsupported PT_INTERP offset %d", phdr.Off) - return loadedELF{}, syserror.ENOEXEC + return loadedELF{}, linuxerr.ENOEXEC } path := make([]byte, phdr.Filesz) @@ -451,12 +450,12 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in if err != nil { // If an interpreter was specified, it should exist. ctx.Infof("Error reading PT_INTERP path: %v", err) - return loadedELF{}, syserror.ENOEXEC + return loadedELF{}, linuxerr.ENOEXEC } if path[len(path)-1] != 0 { ctx.Infof("PT_INTERP path not NUL-terminated: %v", path) - return loadedELF{}, syserror.ENOEXEC + return loadedELF{}, linuxerr.ENOEXEC } // Strip NUL-terminator and everything beyond from @@ -498,7 +497,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in totalSize, ok := totalSize.RoundUp() if !ok { ctx.Infof("ELF PT_LOAD segments too big") - return loadedELF{}, syserror.ENOEXEC + return loadedELF{}, linuxerr.ENOEXEC } var err error @@ -592,7 +591,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS // Check Image Compatibility. if arch.Host != info.arch { ctx.Warningf("Found mismatch for platform %s with ELF type %s", arch.Host.String(), info.arch.String()) - return loadedELF{}, nil, syserror.ENOEXEC + return loadedELF{}, nil, linuxerr.ENOEXEC } // Create the arch.Context now so we can prepare the mmap layout before @@ -681,7 +680,7 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error if interp.interpreter != "" { // No recursive interpreters! ctx.Infof("Interpreter requires an interpreter") - return loadedELF{}, nil, syserror.ENOEXEC + return loadedELF{}, nil, linuxerr.ENOEXEC } } diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go index 3e302d92c..1ec0d7019 100644 --- a/pkg/sentry/loader/interpreter.go +++ b/pkg/sentry/loader/interpreter.go @@ -19,8 +19,8 @@ import ( "io" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsbridge" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -43,14 +43,14 @@ func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.Fil // Short read is OK. if err != nil && err != io.ErrUnexpectedEOF { if err == io.EOF { - err = syserror.ENOEXEC + err = linuxerr.ENOEXEC } return "", []string{}, err } line = line[:n] if !bytes.Equal(line[:2], []byte(interpreterScriptMagic)) { - return "", []string{}, syserror.ENOEXEC + return "", []string{}, linuxerr.ENOEXEC } // Ignore #!. line = line[2:] @@ -82,7 +82,7 @@ func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.Fil if string(interp) == "" { ctx.Infof("Interpreter script contains no interpreter: %v", line) - return "", []string{}, syserror.ENOEXEC + return "", []string{}, linuxerr.ENOEXEC } // Build the new argument list: diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 86d0c54cd..2759ef71e 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -35,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -91,7 +90,7 @@ type LoadArgs struct { func openPath(ctx context.Context, args LoadArgs) (fsbridge.File, error) { if args.Filename == "" { ctx.Infof("cannot open empty name") - return nil, syserror.ENOENT + return nil, linuxerr.ENOENT } // TODO(gvisor.dev/issue/160): Linux requires only execute permission, @@ -172,7 +171,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context // (e.g., #!a). if err != nil && err != io.ErrUnexpectedEOF { if err == io.EOF { - err = syserror.ENOEXEC + err = linuxerr.ENOEXEC } return loadedELF{}, nil, nil, nil, err } @@ -190,7 +189,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)): if args.CloseOnExec { - return loadedELF{}, nil, nil, nil, syserror.ENOENT + return loadedELF{}, nil, nil, nil, linuxerr.ENOENT } args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv) if err != nil { @@ -202,7 +201,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context default: ctx.Infof("Unknown magic: %v", hdr) - return loadedELF{}, nil, nil, nil, syserror.ENOEXEC + return loadedELF{}, nil, nil, nil, linuxerr.ENOEXEC } // Set to nil in case we loop on a Interpreter Script. args.File = nil @@ -296,15 +295,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V m.SetEnvvEnd(sl.EnvvEnd) m.SetAuxv(auxv) m.SetExecutable(ctx, file) - - symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn") - if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to find rt_sigreturn in vdso: %v", err), syserr.FromError(err).ToLinux()) - } - - // Found rt_sigretrun. - addr := uint64(vdsoAddr) + symbolValue - vdsoPrelink - m.SetVDSOSigReturn(addr) + m.SetVDSOSigReturn(uint64(vdsoAddr) + vdsoSigreturnOffset - vdsoPrelink) ac.SetIP(uintptr(loaded.entry)) ac.SetStack(uintptr(stack.Bottom)) diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index 054ef1723..bcee6aef6 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -19,7 +19,6 @@ import ( "debug/elf" "fmt" "io" - "strings" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/context" @@ -34,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -102,14 +100,14 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro first = &info.phdrs[i] if phdr.Off != 0 { log.Warningf("First PT_LOAD segment has non-zero file offset") - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } } memoryOffset := phdr.Vaddr - first.Vaddr if memoryOffset != phdr.Off { log.Warningf("PT_LOAD segment memory offset %#x != file offset %#x", memoryOffset, phdr.Off) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } // memsz larger than filesz means that extra zeroed space should be @@ -118,24 +116,24 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro // zeroes. if phdr.Memsz != phdr.Filesz { log.Warningf("PT_LOAD segment memsz %#x != filesz %#x", phdr.Memsz, phdr.Filesz) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } start := hostarch.Addr(memoryOffset) end, ok := start.AddLength(phdr.Memsz) if !ok { log.Warningf("PT_LOAD segment size overflows: %#x + %#x", start, end) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } if uint64(end) > size { log.Warningf("PT_LOAD segment end %#x extends beyond end of file %#x", end, size) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } if prev != nil { if start < prevEnd { log.Warningf("PT_LOAD segments out of order") - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } // We mprotect entire pages, so each segment must be in @@ -144,7 +142,7 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro startPage := start.RoundDown() if prevEndPage >= startPage { log.Warningf("PT_LOAD segments share a page: %#x", prevEndPage) - return elfInfo{}, syserror.ENOEXEC + return elfInfo{}, linuxerr.ENOEXEC } } prev = &info.phdrs[i] @@ -178,27 +176,6 @@ type VDSO struct { phdrs []elf.ProgHeader `state:".([]elfProgHeader)"` } -// getSymbolValueFromVDSO returns the specific symbol value in vdso.so. -func getSymbolValueFromVDSO(symbol string) (uint64, error) { - f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary)) - if err != nil { - return 0, err - } - syms, err := f.Symbols() - if err != nil { - return 0, err - } - - for _, sym := range syms { - if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF { - if strings.Contains(sym.Name, symbol) { - return sym.Value, nil - } - } - } - return 0, fmt.Errorf("no %v in vdso.so", symbol) -} - // PrepareVDSO validates the system VDSO and returns a VDSO, containing the // param page for updating by the kernel. func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { @@ -271,11 +248,11 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) (hostarch.Addr, error) { if v.os != bin.os { ctx.Warningf("Binary ELF OS %v and VDSO ELF OS %v differ", bin.os, v.os) - return 0, syserror.ENOEXEC + return 0, linuxerr.ENOEXEC } if v.arch != bin.arch { ctx.Warningf("Binary ELF arch %v and VDSO ELF arch %v differ", bin.arch, v.arch) - return 0, syserror.ENOEXEC + return 0, linuxerr.ENOEXEC } // Reserve address space for the VDSO and its parameter page, which is @@ -348,35 +325,35 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) segAddr, ok := vdsoAddr.AddLength(memoryOffset) if !ok { ctx.Warningf("PT_LOAD segment address overflows: %#x + %#x", segAddr, memoryOffset) - return 0, syserror.ENOEXEC + return 0, linuxerr.ENOEXEC } segPage := segAddr.RoundDown() segSize := hostarch.Addr(phdr.Memsz) segSize, ok = segSize.AddLength(segAddr.PageOffset()) if !ok { ctx.Warningf("PT_LOAD segment memsize %#x + offset %#x overflows", phdr.Memsz, segAddr.PageOffset()) - return 0, syserror.ENOEXEC + return 0, linuxerr.ENOEXEC } segSize, ok = segSize.RoundUp() if !ok { ctx.Warningf("PT_LOAD segment size overflows: %#x", phdr.Memsz+segAddr.PageOffset()) - return 0, syserror.ENOEXEC + return 0, linuxerr.ENOEXEC } segEnd, ok := segPage.AddLength(uint64(segSize)) if !ok { ctx.Warningf("PT_LOAD segment range overflows: %#x + %#x", segAddr, segSize) - return 0, syserror.ENOEXEC + return 0, linuxerr.ENOEXEC } if segEnd > vdsoEnd { ctx.Warningf("PT_LOAD segment ends beyond VDSO: %#x > %#x", segEnd, vdsoEnd) - return 0, syserror.ENOEXEC + return 0, linuxerr.ENOEXEC } perms := progFlagsAsPerms(phdr.Flags) if perms != hostarch.Read { if err := m.MProtect(segPage, uint64(segSize), perms, false); err != nil { ctx.Warningf("Unable to set PT_LOAD segment protections %+v at [%#x, %#x): %v", perms, segAddr, segEnd, err) - return 0, syserror.ENOEXEC + return 0, linuxerr.ENOEXEC } } } @@ -389,3 +366,21 @@ func (v *VDSO) Release(ctx context.Context) { v.ParamPage.DecRef(ctx) v.vdso.DecRef(ctx) } + +var vdsoSigreturnOffset = func() uint64 { + f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary)) + if err != nil { + panic(fmt.Sprintf("failed to parse vdso.so as ELF file: %v", err)) + } + syms, err := f.Symbols() + if err != nil { + panic(fmt.Sprintf("failed to read symbols from vdso.so: %v", err)) + } + const sigreturnSymbol = "__kernel_rt_sigreturn" + for _, sym := range syms { + if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF && sym.Name == sigreturnSymbol { + return sym.Value + } + } + panic(fmt.Sprintf("no symbol %q in vdso.so", sigreturnSymbol)) +}() diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index c30e88725..a89bfa680 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -54,7 +54,6 @@ go_library( "//pkg/hostarch", "//pkg/log", "//pkg/safemem", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 69aff21b6..b7d782b7f 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -144,7 +144,6 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sync", - "//pkg/syserror", "//pkg/tcpip/buffer", "//pkg/usermem", ], diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index b7f765cd7..d71d64580 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -77,15 +77,6 @@ func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) return nil } - // Only unmaps after it assured that the address is a valid aio context to - // prevent random memory from been unmapped. - // - // Note: It's possible to unmap this address and map something else into - // the same address. Then it would be unmapping memory that it doesn't own. - // This is, however, the way Linux implements AIO. Keeps the same [weird] - // semantics in case anyone relies on it. - mm.MUnmap(ctx, hostarch.Addr(id), aioRingBufferSize) - delete(mm.aioManager.contexts, id) aioCtx.destroy() return aioCtx @@ -411,6 +402,15 @@ func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOC return nil } + // Only unmaps after it assured that the address is a valid aio context to + // prevent random memory from been unmapped. + // + // Note: It's possible to unmap this address and map something else into + // the same address. Then it would be unmapping memory that it doesn't own. + // This is, however, the way Linux implements AIO. Keeps the same [weird] + // semantics in case anyone relies on it. + mm.MUnmap(ctx, hostarch.Addr(id), aioRingBufferSize) + mm.aioManager.mu.Lock() defer mm.aioManager.mu.Unlock() return mm.destroyAIOContextLocked(ctx, id) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 57969b26c..0fca59b64 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -28,6 +28,7 @@ // memmap.File locks // mm.aioManager.mu // mm.AIOContext.mu +// kernel.TaskSet.mu // // Only mm.MemoryManager.Fork is permitted to lock mm.MemoryManager.activeMu in // multiple mm.MemoryManagers, as it does so in a well-defined order (forked diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 9f4cc238f..05cdcd8ae 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -324,20 +324,37 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter panic(fmt.Sprintf("pma %v needs to be copied for writing, but is not readable: %v", pseg.Range(), oldpma)) } } - // The majority of copy-on-write breaks on executable pages - // come from: - // - // - The ELF loader, which must zero out bytes on the last - // page of each segment after the end of the segment. - // - // - gdb's use of ptrace to insert breakpoints. - // - // Neither of these cases has enough spatial locality to - // benefit from copying nearby pages, so if the vma is - // executable, only copy the pages required. var copyAR hostarch.AddrRange - if vseg.ValuePtr().effectivePerms.Execute { + if vma := vseg.ValuePtr(); vma.effectivePerms.Execute { + // The majority of copy-on-write breaks on executable + // pages come from: + // + // - The ELF loader, which must zero out bytes on the + // last page of each segment after the end of the + // segment. + // + // - gdb's use of ptrace to insert breakpoints. + // + // Neither of these cases has enough spatial locality + // to benefit from copying nearby pages, so if the vma + // is executable, only copy the pages required. copyAR = pseg.Range().Intersect(ar) + } else if vma.growsDown { + // In most cases, the new process will not use most of + // its stack before exiting or invoking execve(); it is + // especially unlikely to return very far down its call + // stack, since async-signal-safety concerns in + // multithreaded programs prevent the new process from + // being able to do much. So only copy up to one page + // before and after the pages required. + stackMaskAR := ar + if newStart := stackMaskAR.Start - hostarch.PageSize; newStart < stackMaskAR.Start { + stackMaskAR.Start = newStart + } + if newEnd := stackMaskAR.End + hostarch.PageSize; newEnd > stackMaskAR.End { + stackMaskAR.End = newEnd + } + copyAR = pseg.Range().Intersect(stackMaskAR) } else { copyAR = pseg.Range().Intersect(maskAR) } diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 256eb4afb..dc12ad357 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/syserror" ) // HandleUserFault handles an application page fault. sp is the faulting @@ -79,7 +78,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar } length, ok := hostarch.Addr(opts.Length).RoundUp() if !ok { - return 0, syserror.ENOMEM + return 0, linuxerr.ENOMEM } opts.Length = uint64(length) @@ -90,7 +89,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar } // Offset + length must not overflow. if end := opts.Offset + opts.Length; end < opts.Offset { - return 0, syserror.ENOMEM + return 0, linuxerr.EOVERFLOW } } else { opts.Offset = 0 @@ -253,7 +252,7 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (hostarch.AddrRange, erro ctx.Warningf("Capping stack size from RLIMIT_STACK of %v down to %v.", sz, maxStackSize) sz = maxStackSize } else if sz == 0 { - return hostarch.AddrRange{}, syserror.ENOMEM + return hostarch.AddrRange{}, linuxerr.ENOMEM } szaddr := hostarch.Addr(sz) ctx.Debugf("Allocating stack with size of %v bytes", sz) @@ -262,7 +261,7 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (hostarch.AddrRange, erro // randomization can't be disabled. stackEnd := mm.layout.MaxAddr - hostarch.Addr(mrand.Int63n(int64(mm.layout.MaxStackRand))).RoundDown() if stackEnd < szaddr { - return hostarch.AddrRange{}, syserror.ENOMEM + return hostarch.AddrRange{}, linuxerr.ENOMEM } stackStart := stackEnd - szaddr mm.mappingMu.Lock() @@ -500,7 +499,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS // Check against RLIMIT_AS. newUsageAS := mm.usageAS - uint64(oldAR.Length()) + uint64(newAR.Length()) if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS { - return 0, syserror.ENOMEM + return 0, linuxerr.ENOMEM } if vma := vseg.ValuePtr(); vma.mappable != nil { @@ -599,11 +598,11 @@ func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms h } rlength, ok := hostarch.Addr(length).RoundUp() if !ok { - return syserror.ENOMEM + return linuxerr.ENOMEM } ar, ok := addr.ToRange(uint64(rlength)) if !ok { - return syserror.ENOMEM + return linuxerr.ENOMEM } effectivePerms := realPerms.Effective() @@ -616,19 +615,19 @@ func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms h // the non-growsDown case. vseg := mm.vmas.LowerBoundSegment(ar.Start) if !vseg.Ok() { - return syserror.ENOMEM + return linuxerr.ENOMEM } if growsDown { if !vseg.ValuePtr().growsDown { return linuxerr.EINVAL } if ar.End <= vseg.Start() { - return syserror.ENOMEM + return linuxerr.ENOMEM } ar.Start = vseg.Start() } else { if ar.Start < vseg.Start() { - return syserror.ENOMEM + return linuxerr.ENOMEM } } @@ -688,7 +687,7 @@ func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms h } vseg, _ = vseg.NextNonEmpty() if !vseg.Ok() { - return syserror.ENOMEM + return linuxerr.ENOMEM } } } @@ -724,7 +723,7 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr hostarch.Addr) (hostarch. if uint64(addr-mm.brk.Start) > limits.FromContext(ctx).Get(limits.Data).Cur { addr = mm.brk.End mm.mappingMu.Unlock() - return addr, syserror.ENOMEM + return addr, linuxerr.ENOMEM } oldbrkpg, _ := mm.brk.End.RoundUp() @@ -798,7 +797,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u } if newLockedAS := mm.lockedAS + uint64(ar.Length()) - mm.mlockedBytesRangeLocked(ar); newLockedAS > mlockLimit { mm.mappingMu.Unlock() - return syserror.ENOMEM + return linuxerr.ENOMEM } } } @@ -835,7 +834,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u mm.vmas.MergeAdjacent(ar) if unmapped { mm.mappingMu.Unlock() - return syserror.ENOMEM + return linuxerr.ENOMEM } if mode == memmap.MLockEager { @@ -850,7 +849,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u // case, which is converted to ENOMEM by mlock. mm.activeMu.Unlock() mm.mappingMu.RUnlock() - return syserror.ENOMEM + return linuxerr.ENOMEM } _, _, err := mm.getPMAsLocked(ctx, vseg, vseg.Range().Intersect(ar), hostarch.NoAccess) if err != nil { @@ -858,7 +857,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u mm.mappingMu.RUnlock() // Linux: mm/mlock.c:__mlock_posix_error_return() if linuxerr.Equals(linuxerr.EFAULT, err) { - return syserror.ENOMEM + return linuxerr.ENOMEM } if linuxerr.Equals(linuxerr.ENOMEM, err) { return linuxerr.EAGAIN @@ -917,7 +916,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error } if uint64(mm.vmas.Span()) > mlockLimit { mm.mappingMu.Unlock() - return syserror.ENOMEM + return linuxerr.ENOMEM } } } @@ -1040,7 +1039,7 @@ func (mm *MemoryManager) SetDontFork(addr hostarch.Addr, length uint64, dontfork } if mm.vmas.SpanRange(ar) != ar.Length() { - return syserror.ENOMEM + return linuxerr.ENOMEM } return nil } @@ -1099,7 +1098,7 @@ func (mm *MemoryManager) Decommit(addr hostarch.Addr, length uint64) error { // to the rest (but returns ENOMEM from the system call, as it should)." - // madvise(2) if mm.vmas.SpanRange(ar) != ar.Length() { - return syserror.ENOMEM + return linuxerr.ENOMEM } return nil } @@ -1123,11 +1122,11 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length u } la, ok := hostarch.Addr(length).RoundUp() if !ok { - return syserror.ENOMEM + return linuxerr.ENOMEM } ar, ok := addr.ToRange(uint64(la)) if !ok { - return syserror.ENOMEM + return linuxerr.ENOMEM } mm.mappingMu.RLock() @@ -1135,7 +1134,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length u vseg := mm.vmas.LowerBoundSegment(ar.Start) if !vseg.Ok() { mm.mappingMu.RUnlock() - return syserror.ENOMEM + return linuxerr.ENOMEM } var unmapped bool lastEnd := ar.Start @@ -1184,7 +1183,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length u } if unmapped { - return syserror.ENOMEM + return linuxerr.ENOMEM } return nil } diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 5f8ab7ca3..e34b7a2f7 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/syserror" ) // Preconditions: @@ -59,7 +58,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp newUsageAS -= uint64(mm.vmas.SpanRange(ar)) } if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS { - return vmaIterator{}, hostarch.AddrRange{}, syserror.ENOMEM + return vmaIterator{}, hostarch.AddrRange{}, linuxerr.ENOMEM } if opts.MLockMode != memmap.MLockNone { @@ -178,7 +177,7 @@ func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOp // Fixed mappings accept only the requested address. if opts.Fixed { - return 0, syserror.ENOMEM + return 0, linuxerr.ENOMEM } // Prefer hugepage alignment if a hugepage or more is requested. @@ -216,7 +215,7 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou return gr.Start, nil } } - return 0, syserror.ENOMEM + return 0, linuxerr.ENOMEM } // Preconditions: mm.mappingMu must be locked. @@ -236,7 +235,7 @@ func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bo return start, nil } } - return 0, syserror.ENOMEM + return 0, linuxerr.ENOMEM } // Preconditions: mm.mappingMu must be locked. diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index d351869ef..496a9fd97 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -97,7 +97,6 @@ go_library( "//pkg/state", "//pkg/state/wire", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 0c8542485..68e17d343 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -39,7 +39,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // MemoryFile is a memmap.File whose pages may be allocated to arbitrary @@ -404,7 +403,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File // Find a range in the underlying file. fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment) if !ok { - return memmap.FileRange{}, syserror.ENOMEM + return memmap.FileRange{}, linuxerr.ENOMEM } // Expand the file if needed. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 8a490b3de..834d72408 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,13 +1,26 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "atomicptr_machine", + out = "atomicptr_machine_unsafe.go", + package = "kvm", + prefix = "machine", + template = "//pkg/sync/atomicptr:generic_atomicptr", + types = { + "Value": "machine", + }, +) + go_library( name = "kvm", srcs = [ "address_space.go", "address_space_amd64.go", "address_space_arm64.go", + "atomicptr_machine_unsafe.go", "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", @@ -50,7 +63,6 @@ go_library( "//pkg/procid", "//pkg/ring0", "//pkg/ring0/pagetables", - "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", @@ -58,6 +70,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/time", + "//pkg/sighandling", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], @@ -69,10 +82,17 @@ go_test( "kvm_amd64_test.go", "kvm_amd64_test.s", "kvm_arm64_test.go", + "kvm_safecopy_test.go", "kvm_test.go", "virtual_map_test.go", ], library = ":kvm", + # FIXME(gvisor.dev/issue/3374): Not working with all build systems. + nogo = False, + # cgo has to be disabled. We have seen libc that blocks all signals and + # calls mmap from pthread_create, but we use SIGSYS to trap mmap system + # calls. + pure = True, tags = [ "manual", "nogotsan", @@ -81,8 +101,10 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/hostarch", + "//pkg/memutil", "//pkg/ring0", "//pkg/ring0/pagetables", + "//pkg/safecopy", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", "//pkg/sentry/platform", diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index bb9967b9f..5be2215ed 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -19,8 +19,8 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/ring0" - "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sighandling" ) // bluepill enters guest mode. @@ -61,6 +61,9 @@ var ( // This is called by bluepillHandler. savedHandler uintptr + // savedSigsysHandler is a pointer to the previos handler of the SIGSYS signals. + savedSigsysHandler uintptr + // dieTrampolineAddr is the address of dieTrampoline. dieTrampolineAddr uintptr ) @@ -94,7 +97,7 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { func init() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 0567c8d32..b2db2bb9f 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -71,10 +71,6 @@ func (c *vCPU) KernelSyscall() { if regs.Rax != ^uint64(0) { regs.Rip -= 2 // Rewind. } - // We only trigger a bluepill entry in the bluepill function, and can - // therefore be guaranteed that there is no floating point state to be - // loaded on resuming from halt. We only worry about saving on exit. - ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no. // N.B. Since KernelSyscall is called when the kernel makes a syscall, // FS_BASE is already set for correct execution of this function. // @@ -112,8 +108,6 @@ func (c *vCPU) KernelException(vector ring0.Vector) { regs.Rip = 0 } // See above. - ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no. - // See above. ring0.HaltAndWriteFSBase(regs) // escapes: no, reload host segment. } @@ -144,5 +138,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { // Set the context pointer to the saved floating point state. This is // where the guest data has been serialized, the kernel will restore // from this new pointer value. - context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer())) + context.Fpstate = uint64(uintptrValue(c.FloatingPointState().BytePointer())) // escapes: no. } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index c2a1dca11..5d8358f64 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -32,6 +32,8 @@ // This is checked as the source of the fault. #define CLI $0xfa +#define SYS_MMAP 9 + // See bluepill.go. TEXT ·bluepill(SB),NOSPLIT,$0 begin: @@ -95,6 +97,31 @@ TEXT ·addrOfSighandler(SB), $0-8 MOVQ AX, ret+0(FP) RET +TEXT ·sigsysHandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel. + MOVQ $1, CX + CMPL CX, 0x8(SI) + JNE fallback + + MOVL CONTEXT_RAX(DX), CX + CMPL CX, $SYS_MMAP + JNE fallback + PUSHQ DX // First argument (context). + CALL ·seccompMmapHandler(SB) // Call the handler. + POPQ DX // Discard the argument. + RET +fallback: + // Jump to the previous signal handler. + XORQ CX, CX + MOVQ ·savedSigsysHandler(SB), AX + JMP AX + +// func addrOfSighandler() uintptr +TEXT ·addrOfSigsysHandler(SB), $0-8 + MOVQ $·sigsysHandler(SB), AX + MOVQ AX, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 PUSHQ BX // First argument (vCPU). diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index acb0cb05f..df772d620 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -70,7 +70,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { lazyVfp := c.GetLazyVFP() if lazyVfp != 0 { - fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) + fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no context.Fpsimd64.Fpsr = fpsimd.Fpsr context.Fpsimd64.Fpcr = fpsimd.Fpcr context.Fpsimd64.Vregs = fpsimd.Vregs @@ -90,12 +90,12 @@ func (c *vCPU) KernelSyscall() { fpDisableTrap := ring0.CPACREL1() if fpDisableTrap != 0 { - fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) + fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no fpcr := ring0.GetFPCR() fpsr := ring0.GetFPSR() fpsimd.Fpcr = uint32(fpcr) fpsimd.Fpsr = uint32(fpsr) - ring0.SaveVRegs(c.floatingPointState.BytePointer()) + ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no } ring0.Halt() @@ -114,12 +114,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) { fpDisableTrap := ring0.CPACREL1() if fpDisableTrap != 0 { - fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) + fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no fpcr := ring0.GetFPCR() fpsr := ring0.GetFPSR() fpsimd.Fpcr = uint32(fpcr) fpsimd.Fpsr = uint32(fpsr) - ring0.SaveVRegs(c.floatingPointState.BytePointer()) + ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no } ring0.Halt() diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 308f2a951..9690e3772 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -29,9 +29,12 @@ // Only limited use of the context is done in the assembly stub below, most is // done in the Go handlers. #define SIGINFO_SIGNO 0x0 +#define SIGINFO_CODE 0x8 #define CONTEXT_PC 0x1B8 #define CONTEXT_R0 0xB8 +#define SYS_MMAP 222 + // getTLS returns the value of TPIDR_EL0 register. TEXT ·getTLS(SB),NOSPLIT,$0-8 MRS TPIDR_EL0, R1 @@ -98,6 +101,37 @@ TEXT ·addrOfSighandler(SB), $0-8 MOVD R0, ret+0(FP) RET +// The arguments are the following: +// +// R0 - The signal number. +// R1 - Pointer to siginfo_t structure. +// R2 - Pointer to ucontext structure. +// +TEXT ·sigsysHandler(SB),NOSPLIT,$0 + // si_code should be SYS_SECCOMP. + MOVD SIGINFO_CODE(R1), R7 + CMPW $1, R7 + BNE fallback + + CMPW $SYS_MMAP, R8 + BNE fallback + + MOVD R2, 8(RSP) + BL ·seccompMmapHandler(SB) // Call the handler. + + RET + +fallback: + // Jump to the previous signal handler. + MOVD ·savedHandler(SB), R7 + B (R7) + +// func addrOfSighandler() uintptr +TEXT ·addrOfSigsysHandler(SB), $0-8 + MOVD $·sigsysHandler(SB), R0 + MOVD R0, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 // R0: Fake the old PC as caller diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index 8fd8287b3..7a3c97c5a 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -55,11 +55,7 @@ func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virt } // Adjust the block to match our size. - physicalStart = alignedPhysical & faultBlockMask - if physicalStart < pr.physical { - // Bound the starting point to the start of the region. - physicalStart = pr.physical - } + physicalStart = pr.physical + (alignedPhysical-pr.physical)&faultBlockMask virtualStart = pr.virtual + (physicalStart - pr.physical) physicalEnd := physicalStart + faultBlockSize if physicalEnd > end { diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 0f0c1e73b..e38ca05c0 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -193,36 +193,8 @@ func bluepillHandler(context unsafe.Pointer) { return } - // Increment the fault count. - atomic.AddUint32(&c.faults, 1) - - // For MMIO, the physical address is the first data item. - physical = uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) - if !ok { - c.die(bluepillArchContext(context), "invalid physical address") - return - } - - // We now need to fill in the data appropriately. KVM - // expects us to provide the result of the given MMIO - // operation in the runData struct. This is safe - // because, if a fault occurs here, the same fault - // would have occurred in guest mode. The kernel should - // not create invalid page table mappings. - data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1])) - length := (uintptr)((uint32)(c.runData.data[2])) - write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0 - for i := uintptr(0); i < length; i++ { - b := bytePtr(uintptr(virtual) + i) - if write { - // Write to the given address. - *b = data[i] - } else { - // Read from the given address. - data[i] = *b - } - } + c.die(bluepillArchContext(context), "exit_mmio") + return case _KVM_EXIT_IRQ_WINDOW_OPEN: bluepillStopGuest(c) case _KVM_EXIT_SHUTDOWN: diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index aac0fdffe..ad6863646 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -77,7 +77,11 @@ var ( // OpenDevice opens the KVM device at /dev/kvm and returns the File. func OpenDevice() (*os.File, error) { - f, err := os.OpenFile("/dev/kvm", unix.O_RDWR, 0) + dev, ok := os.LookupEnv("GVISOR_KVM_DEV") + if !ok { + dev = "/dev/kvm" + } + f, err := os.OpenFile(dev, unix.O_RDWR, 0) if err != nil { return nil, fmt.Errorf("error opening /dev/kvm: %v", err) } diff --git a/pkg/sentry/platform/kvm/kvm_safecopy_test.go b/pkg/sentry/platform/kvm/kvm_safecopy_test.go new file mode 100644 index 000000000..9a87c9e6f --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_safecopy_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// FIXME(gvisor.dev/issue//6629): These tests don't pass on ARM64. +// +//go:build amd64 +// +build amd64 + +package kvm + +import ( + "fmt" + "os" + "testing" + "unsafe" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/safecopy" +) + +func testSafecopy(t *testing.T, mapSize uintptr, fileSize uintptr, testFunc func(t *testing.T, c *vCPU, addr uintptr)) { + memfd, err := memutil.CreateMemFD(fmt.Sprintf("kvm_test_%d", os.Getpid()), 0) + if err != nil { + t.Errorf("error creating memfd: %v", err) + } + + memfile := os.NewFile(uintptr(memfd), "kvm_test") + memfile.Truncate(int64(fileSize)) + kvmTest(t, nil, func(c *vCPU) bool { + const n = 10 + mappings := make([]uintptr, n) + defer func() { + for i := 0; i < n && mappings[i] != 0; i++ { + unix.RawSyscall( + unix.SYS_MUNMAP, + mappings[i], mapSize, 0) + } + }() + for i := 0; i < n; i++ { + addr, _, errno := unix.RawSyscall6( + unix.SYS_MMAP, + 0, + mapSize, + unix.PROT_READ|unix.PROT_WRITE, + unix.MAP_SHARED|unix.MAP_FILE, + uintptr(memfile.Fd()), + 0) + if errno != 0 { + t.Errorf("error mapping file: %v", errno) + } + mappings[i] = addr + testFunc(t, c, addr) + } + return false + }) +} + +func TestSafecopySigbus(t *testing.T) { + mapSize := uintptr(faultBlockSize) + fileSize := mapSize - hostarch.PageSize + buf := make([]byte, hostarch.PageSize) + testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) { + want := safecopy.BusError{addr + fileSize} + bluepill(c) + _, err := safecopy.CopyIn(buf, unsafe.Pointer(addr+fileSize)) + if err != want { + t.Errorf("expected error: got %v, want %v", err, want) + } + }) +} + +func TestSafecopy(t *testing.T) { + mapSize := uintptr(faultBlockSize) + fileSize := mapSize + testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) { + want := uint32(0x12345678) + bluepill(c) + _, err := safecopy.SwapUint32(unsafe.Pointer(addr+fileSize-8), want) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + bluepill(c) + val, err := safecopy.LoadUint32(unsafe.Pointer(addr + fileSize - 8)) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if val != want { + t.Errorf("incorrect value: got %x, want %x", val, want) + } + }) +} diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index e7092a756..f1f7e4ea4 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -17,16 +17,20 @@ package kvm import ( "fmt" "runtime" + gosync "sync" "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" + "gvisor.dev/gvisor/pkg/seccomp" ktime "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sighandling" "gvisor.dev/gvisor/pkg/sync" ) @@ -35,6 +39,9 @@ type machine struct { // fd is the vm fd. fd int + // machinePoolIndex is the index in the machinePool array. + machinePoolIndex uint32 + // nextSlot is the next slot for setMemoryRegion. // // This must be accessed atomically. If nextSlot is ^uint32(0), then @@ -192,6 +199,10 @@ func (m *machine) newVCPU() *vCPU { return c // Done. } +// readOnlyGuestRegions contains regions that have to be mapped read-only into +// the guest physical address space. Right now, it is used on arm64 only. +var readOnlyGuestRegions []region + // newMachine returns a new VM context. func newMachine(vm int) (*machine, error) { // Create the machine. @@ -227,6 +238,10 @@ func newMachine(vm int) (*machine, error) { m.upperSharedPageTables.MarkReadOnlyShared() m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress) + // Install seccomp rules to trap runtime mmap system calls. They will + // be handled by seccompMmapHandler. + seccompMmapRules(m) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -241,32 +256,11 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) - var physicalRegionsReadOnly []physicalRegion - var physicalRegionsAvailable []physicalRegion - - physicalRegionsReadOnly = rdonlyRegionsForSetMem() - physicalRegionsAvailable = availableRegionsForSetMem() - - // Map all read-only regions. - for _, r := range physicalRegionsReadOnly { - m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) - } - // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to // ensure successful vCPU entry. - applyVirtualRegions(func(vr virtualRegion) { - if excludeVirtualRegion(vr) { - return // skip region. - } - - for _, r := range physicalRegionsReadOnly { - if vr.virtual == r.virtual { - return - } - } - + mapRegion := func(vr region, flags uint32) { for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -280,9 +274,32 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) + m.mapPhysical(physical, length, physicalRegions, flags) virtual += length } + } + + for _, vr := range readOnlyGuestRegions { + mapRegion(vr, _KVM_MEM_READONLY) + } + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + for _, r := range readOnlyGuestRegions { + if vr.virtual == r.virtual { + return + } + } + // Take into account that the stack can grow down. + if vr.filename == "[stack]" { + vr.virtual -= 1 << 20 + vr.length += 1 << 20 + } + + mapRegion(vr.region, 0) + }) // Initialize architecture state. @@ -352,6 +369,10 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg func (m *machine) Destroy() { runtime.SetFinalizer(m, nil) + machinePoolMu.Lock() + machinePool[m.machinePoolIndex].Store(nil) + machinePoolMu.Unlock() + // Destroy vCPUs. for _, c := range m.vCPUsByID { if c == nil { @@ -519,15 +540,21 @@ func (c *vCPU) lock() { // //go:nosplit func (c *vCPU) unlock() { - if atomic.CompareAndSwapUint32(&c.state, vCPUUser|vCPUGuest, vCPUGuest) { + origState := atomicbitops.CompareAndSwapUint32(&c.state, vCPUUser|vCPUGuest, vCPUGuest) + if origState == vCPUUser|vCPUGuest { // Happy path: no exits are forced, and we can continue // executing on our merry way with a single atomic access. return } // Clear the lock. - origState := atomic.LoadUint32(&c.state) - atomicbitops.AndUint32(&c.state, ^vCPUUser) + for { + state := atomicbitops.CompareAndSwapUint32(&c.state, origState, origState&^vCPUUser) + if state == origState { + break + } + origState = state + } switch origState { case vCPUUser: // Normal state. @@ -677,3 +704,72 @@ func (c *vCPU) setSystemTimeLegacy() error { } } } + +const machinePoolSize = 16 + +// machinePool is enumerated from the seccompMmapHandler signal handler +var ( + machinePool [machinePoolSize]machineAtomicPtr + machinePoolLen uint32 + machinePoolMu sync.Mutex + seccompMmapRulesOnce gosync.Once +) + +func sigsysHandler() +func addrOfSigsysHandler() uintptr + +// seccompMmapRules adds seccomp rules to trap mmap system calls that will be +// handled in seccompMmapHandler. +func seccompMmapRules(m *machine) { + seccompMmapRulesOnce.Do(func() { + // Install the handler. + if err := sighandling.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) + } + rules := []seccomp.RuleSet{} + rules = append(rules, []seccomp.RuleSet{ + // Trap mmap system calls and handle them in sigsysGoHandler + { + Rules: seccomp.SyscallRules{ + unix.SYS_MMAP: { + { + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + /* MAP_DENYWRITE is ignored and used only for filtering. */ + seccomp.MaskedEqual(unix.MAP_DENYWRITE, 0), + }, + }, + }, + Action: linux.SECCOMP_RET_TRAP, + }, + }...) + instrs, err := seccomp.BuildProgram(rules, linux.SECCOMP_RET_ALLOW, linux.SECCOMP_RET_ALLOW) + if err != nil { + panic(fmt.Sprintf("failed to build rules: %v", err)) + } + // Perform the actual installation. + if err := seccomp.SetFilter(instrs); err != nil { + panic(fmt.Sprintf("failed to set filter: %v", err)) + } + }) + + machinePoolMu.Lock() + n := atomic.LoadUint32(&machinePoolLen) + i := uint32(0) + for ; i < n; i++ { + if machinePool[i].Load() == nil { + break + } + } + if i == n { + if i == machinePoolSize { + machinePoolMu.Unlock() + panic("machinePool is full") + } + atomic.AddUint32(&machinePoolLen, 1) + } + machinePool[i].Store(m) + m.machinePoolIndex = i + machinePoolMu.Unlock() +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index a96634381..5bc023899 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" ) @@ -72,10 +71,6 @@ type vCPUArchState struct { // // This starts above fixedKernelPCID. PCIDs *pagetables.PCIDs - - // floatingPointState is the floating point state buffer used in guest - // to host transitions. See usage in bluepill_amd64.go. - floatingPointState fpu.State } const ( @@ -152,12 +147,6 @@ func (c *vCPU) initArchState() error { return fmt.Errorf("error setting user registers: %v", errno) } - // Allocate some floating point state save area for the local vCPU. - // This will be saved prior to leaving the guest, and we restore from - // this always. We cannot use the pointer in the context alone because - // we don't know how large the area there is in reality. - c.floatingPointState = fpu.NewState() - // Set the time offset to the host native time. return c.setSystemTime() } @@ -309,22 +298,6 @@ func loadByte(ptr *byte) byte { return *ptr } -// prefaultFloatingPointState touches each page of the floating point state to -// be sure that its physical pages are mapped. -// -// Otherwise the kernel can trigger KVM_EXIT_MMIO and an instruction that -// triggered a fault will be emulated by the kvm kernel code, but it can't -// emulate instructions like xsave and xrstor. -// -//go:nosplit -func prefaultFloatingPointState(data *fpu.State) { - size := len(*data) - for i := 0; i < size; i += hostarch.PageSize { - loadByte(&(*data)[i]) - } - loadByte(&(*data)[size-1]) -} - // SwitchToUser unpacks architectural-details. func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. @@ -355,11 +328,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) // allocations occur. entersyscall() bluepill(c) - // The root table physical page has to be mapped to not fault in iret - // or sysret after switching into a user address space. sysret and - // iret are in the upper half that is global and already mapped. - switchOpts.PageTables.PrefaultRootTable() - prefaultFloatingPointState(switchOpts.FloatingPointState) vector = c.CPU.SwitchToUser(switchOpts) exitsyscall() @@ -522,3 +490,7 @@ func (m *machine) getNewVCPU() *vCPU { } return nil } + +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index de798bb2c..fbacea9ad 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -161,3 +161,15 @@ func (c *vCPU) getSystemRegisters(sregs *systemRegs) unix.Errno { } return 0 } + +//go:nosplit +func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) { + ctx := bluepillArchContext(context) + + // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters. + addr, _, e := unix.RawSyscall6(uintptr(ctx.Rax), uintptr(ctx.Rdi), uintptr(ctx.Rsi), + uintptr(ctx.Rdx), uintptr(ctx.R10)|unix.MAP_DENYWRITE, uintptr(ctx.R8), uintptr(ctx.R9)) + ctx.Rax = uint64(addr) + + return addr, uintptr(ctx.Rsi), e +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 7937a8481..31998a600 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ) @@ -40,10 +39,6 @@ type vCPUArchState struct { // // This starts above fixedKernelPCID. PCIDs *pagetables.PCIDs - - // floatingPointState is the floating point state buffer used in guest - // to host transitions. See usage in bluepill_arm64.go. - floatingPointState fpu.State } const ( @@ -110,18 +105,128 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { return phyRegions } +// archPhysicalRegions fills readOnlyGuestRegions and allocates separate +// physical regions form them. +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + if !vr.accessType.Write { + readOnlyGuestRegions = append(readOnlyGuestRegions, vr.region) + } + }) + + rdRegions := readOnlyGuestRegions[:] + + // Add an unreachable region. + rdRegions = append(rdRegions, region{ + virtual: 0xffffffffffffffff, + length: 0, + }) + + var regions []physicalRegion + addValidRegion := func(r *physicalRegion, virtual, length uintptr) { + if length == 0 { + return + } + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), + }) + } + i := 0 + for _, pr := range physicalRegions { + start := pr.virtual + end := pr.virtual + pr.length + for start < end { + rdRegion := rdRegions[i] + rdStart := rdRegion.virtual + rdEnd := rdRegion.virtual + rdRegion.length + if rdEnd <= start { + i++ + continue + } + if rdStart > start { + newEnd := rdStart + if end < rdStart { + newEnd = end + } + addValidRegion(&pr, start, newEnd-start) + start = rdStart + continue + } + if rdEnd < end { + addValidRegion(&pr, start, rdEnd-start) + start = rdEnd + continue + } + addValidRegion(&pr, start, end-start) + start = end + } + } + + return regions +} + // Get all available physicalRegions. -func availableRegionsForSetMem() (phyRegions []physicalRegion) { - var excludeRegions []region +func availableRegionsForSetMem() []physicalRegion { + var excludedRegions []region applyVirtualRegions(func(vr virtualRegion) { if !vr.accessType.Write { - excludeRegions = append(excludeRegions, vr.region) + excludedRegions = append(excludedRegions, vr.region) } }) - phyRegions = computePhysicalRegions(excludeRegions) + // Add an unreachable region. + excludedRegions = append(excludedRegions, region{ + virtual: 0xffffffffffffffff, + length: 0, + }) - return phyRegions + var regions []physicalRegion + addValidRegion := func(r *physicalRegion, virtual, length uintptr) { + if length == 0 { + return + } + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), + }) + } + i := 0 + for _, pr := range physicalRegions { + start := pr.virtual + end := pr.virtual + pr.length + for start < end { + er := excludedRegions[i] + excludeEnd := er.virtual + er.length + excludeStart := er.virtual + if excludeEnd < start { + i++ + continue + } + if excludeStart < start { + start = excludeEnd + i++ + continue + } + rend := excludeStart + if rend > end { + rend = end + } + addValidRegion(&pr, start, rend-start) + start = excludeEnd + } + } + + return regions } // nonCanonical generates a canonical address return. diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 1a4a9ce7d..e73d5c544 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" ) @@ -159,8 +158,6 @@ func (c *vCPU) initArchState() error { c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs) } - c.floatingPointState = fpu.NewState() - return c.setSystemTime() } @@ -333,3 +330,15 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) } } + +//go:nosplit +func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) { + ctx := bluepillArchContext(context) + + // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters. + addr, _, e := unix.RawSyscall6(uintptr(ctx.Regs[8]), uintptr(ctx.Regs[0]), uintptr(ctx.Regs[1]), + uintptr(ctx.Regs[2]), uintptr(ctx.Regs[3])|unix.MAP_DENYWRITE, uintptr(ctx.Regs[4]), uintptr(ctx.Regs[5])) + ctx.Regs[0] = uint64(addr) + + return addr, uintptr(ctx.Regs[1]), e +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index cc3a1253b..cf3a4e7c9 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -171,3 +171,46 @@ func (c *vCPU) setSignalMask() error { return nil } + +// seccompMmapHandler is a signal handler for runtime mmap system calls +// that are trapped by seccomp. +// +// It executes the mmap syscall with specified arguments and maps a new region +// to the guest. +// +//go:nosplit +func seccompMmapHandler(context unsafe.Pointer) { + addr, length, errno := seccompMmapSyscall(context) + if errno != 0 { + return + } + + for i := uint32(0); i < atomic.LoadUint32(&machinePoolLen); i++ { + m := machinePool[i].Load() + if m == nil { + continue + } + + // Map the new region to the guest. + vr := region{ + virtual: addr, + length: length, + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { + physical, length, ok := translateToPhysical(virtual) + if !ok { + // This must be an invalid region that was + // knocked out by creation of the physical map. + return + } + if virtual+length > vr.virtual+vr.length { + // Cap the length to the end of the area. + length = vr.virtual + vr.length - virtual + } + + // Ensure the physical range is mapped. + m.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) + virtual += length + } + } +} diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index d812e6c26..9864d1258 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -168,6 +168,9 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica } addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd) + // Do arch-specific actions on physical regions. + physicalRegions = archPhysicalRegions(physicalRegions) + // Dump our all physical regions. for _, r := range physicalRegions { log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)", diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go index 6d0ba8252..346a10043 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -30,8 +30,8 @@ import ( func TLSWorks() bool // SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *arch.Registers, fn func()) { - regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) +func SetTestTarget(regs *arch.Registers, fn uintptr) { + regs.Pc = uint64(fn) } // SetTouchTarget sets rax appropriately. diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 7348c29a5..42876245a 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -28,6 +28,11 @@ TEXT ·Getpid(SB),NOSPLIT,$0 SVC RET +TEXT ·AddrOfGetpid(SB),NOSPLIT,$0-8 + MOVD $·Getpid(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·Touch(SB),NOSPLIT,$0 start: MOVD 0(R8), R1 @@ -35,21 +40,41 @@ start: SVC B start +TEXT ·AddrOfTouch(SB),NOSPLIT,$0-8 + MOVD $·Touch(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·HaltLoop(SB),NOSPLIT,$0 start: HLT B start +TEXT ·AddOfHaltLoop(SB),NOSPLIT,$0-8 + MOVD $·HaltLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + // This function simulates a loop of syscall. TEXT ·SyscallLoop(SB),NOSPLIT,$0 start: SVC B start +TEXT ·AddrOfSyscallLoop(SB),NOSPLIT,$0-8 + MOVD $·SyscallLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·SpinLoop(SB),NOSPLIT,$0 start: B start +TEXT ·AddrOfSpinLoop(SB),NOSPLIT,$0-8 + MOVD $·SpinLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·TLSWorks(SB),NOSPLIT,$0-8 NO_LOCAL_POINTERS MOVD $0x6789, R5 @@ -125,6 +150,11 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 SVC RET // never reached +TEXT ·AddrOfTwiddleRegsSyscall(SB),NOSPLIT,$0-8 + MOVD $·TwiddleRegsSyscall(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 TWIDDLE_REGS() MSR R10, TPIDR_EL0 @@ -132,3 +162,8 @@ TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 // Branch to Register branches unconditionally to an address in <Rn>. JMP (R6) // <=> br x6, must fault RET // never reached + +TEXT ·AddrOfTwiddleRegsFault(SB),NOSPLIT,$0-8 + MOVD $·TwiddleRegsFault(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD new file mode 100644 index 000000000..35feb969f --- /dev/null +++ b/pkg/sentry/seccheck/BUILD @@ -0,0 +1,58 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_fieldenum:defs.bzl", "go_fieldenum") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +licenses(["notice"]) + +go_fieldenum( + name = "seccheck_fieldenum", + srcs = [ + "clone.go", + "execve.go", + "exit.go", + "task.go", + ], + out = "seccheck_fieldenum.go", + package = "seccheck", +) + +go_template_instance( + name = "seqatomic_checkerslice", + out = "seqatomic_checkerslice_unsafe.go", + package = "seccheck", + suffix = "CheckerSlice", + template = "//pkg/sync/seqatomic:generic_seqatomic", + types = { + "Value": "[]Checker", + }, +) + +go_library( + name = "seccheck", + srcs = [ + "clone.go", + "execve.go", + "exit.go", + "seccheck.go", + "seccheck_fieldenum.go", + "seqatomic_checkerslice_unsafe.go", + "task.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/gohacks", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", + "//pkg/sync", + ], +) + +go_test( + name = "seccheck_test", + size = "small", + srcs = ["seccheck_test.go"], + library = ":seccheck", + deps = ["//pkg/context"], +) diff --git a/pkg/sentry/seccheck/clone.go b/pkg/sentry/seccheck/clone.go new file mode 100644 index 000000000..7546fa021 --- /dev/null +++ b/pkg/sentry/seccheck/clone.go @@ -0,0 +1,53 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// CloneInfo contains information used by the Clone checkpoint. +// +// +fieldenum Clone +type CloneInfo struct { + // Invoker identifies the invoking thread. + Invoker TaskInfo + + // Credentials are the invoking thread's credentials. + Credentials *auth.Credentials + + // Args contains the arguments to kernel.Task.Clone(). + Args linux.CloneArgs + + // Created identifies the created thread. + Created TaskInfo +} + +// CloneReq returns fields required by the Clone checkpoint. +func (s *state) CloneReq() CloneFieldSet { + return s.cloneReq.Load() +} + +// Clone is called at the Clone checkpoint. +func (s *state) Clone(ctx context.Context, mask CloneFieldSet, info *CloneInfo) error { + for _, c := range s.getCheckers() { + if err := c.Clone(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/execve.go b/pkg/sentry/seccheck/execve.go new file mode 100644 index 000000000..f36e0730e --- /dev/null +++ b/pkg/sentry/seccheck/execve.go @@ -0,0 +1,65 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// ExecveInfo contains information used by the Execve checkpoint. +// +// +fieldenum Execve +type ExecveInfo struct { + // Invoker identifies the invoking thread. + Invoker TaskInfo + + // Credentials are the invoking thread's credentials. + Credentials *auth.Credentials + + // BinaryPath is a path to the executable binary file being switched to in + // the mount namespace in which it was opened. + BinaryPath string + + // Argv is the new process image's argument vector. + Argv []string + + // Env is the new process image's environment variables. + Env []string + + // BinaryMode is the executable binary file's mode. + BinaryMode uint16 + + // BinarySHA256 is the SHA-256 hash of the executable binary file. + // + // Note that this requires reading the entire file into memory, which is + // likely to be extremely slow. + BinarySHA256 [32]byte +} + +// ExecveReq returns fields required by the Execve checkpoint. +func (s *state) ExecveReq() ExecveFieldSet { + return s.execveReq.Load() +} + +// Execve is called at the Execve checkpoint. +func (s *state) Execve(ctx context.Context, mask ExecveFieldSet, info *ExecveInfo) error { + for _, c := range s.getCheckers() { + if err := c.Execve(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/exit.go b/pkg/sentry/seccheck/exit.go new file mode 100644 index 000000000..69cb6911c --- /dev/null +++ b/pkg/sentry/seccheck/exit.go @@ -0,0 +1,57 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" +) + +// ExitNotifyParentInfo contains information used by the ExitNotifyParent +// checkpoint. +// +// +fieldenum ExitNotifyParent +type ExitNotifyParentInfo struct { + // Exiter identifies the exiting thread. Note that by the checkpoint's + // definition, Exiter.ThreadID == Exiter.ThreadGroupID and + // Exiter.ThreadStartTime == Exiter.ThreadGroupStartTime, so requesting + // ThreadGroup* fields is redundant. + Exiter TaskInfo + + // ExitStatus is the exiting thread group's exit status, as reported + // by wait*(). + ExitStatus linux.WaitStatus +} + +// ExitNotifyParentReq returns fields required by the ExitNotifyParent +// checkpoint. +func (s *state) ExitNotifyParentReq() ExitNotifyParentFieldSet { + return s.exitNotifyParentReq.Load() +} + +// ExitNotifyParent is called at the ExitNotifyParent checkpoint. +// +// The ExitNotifyParent checkpoint occurs when a zombied thread group leader, +// not waiting for exit acknowledgement from a non-parent ptracer, becomes the +// last non-dead thread in its thread group and notifies its parent of its +// exiting. +func (s *state) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info *ExitNotifyParentInfo) error { + for _, c := range s.getCheckers() { + if err := c.ExitNotifyParent(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go new file mode 100644 index 000000000..e13274096 --- /dev/null +++ b/pkg/sentry/seccheck/seccheck.go @@ -0,0 +1,158 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package seccheck defines a structure for dynamically-configured security +// checks in the sentry. +package seccheck + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sync" +) + +// A Point represents a checkpoint, a point at which a security check occurs. +type Point uint + +// PointX represents the checkpoint X. +const ( + PointClone Point = iota + PointExecve + PointExitNotifyParent + // Add new Points above this line. + pointLength + + numPointBitmaskUint32s = (int(pointLength)-1)/32 + 1 +) + +// A Checker performs security checks at checkpoints. +// +// Each Checker method X is called at checkpoint X; if the method may return a +// non-nil error and does so, it causes the checked operation to fail +// immediately (without calling subsequent Checkers) and return the error. The +// info argument contains information relevant to the check. The mask argument +// indicates what fields in info are valid; the mask should usually be a +// superset of fields requested by the Checker's corresponding CheckerReq, but +// may be missing requested fields in some cases (e.g. if the Checker is +// registered concurrently with invocations of checkpoints). +type Checker interface { + Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error + Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error + ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error +} + +// CheckerDefaults may be embedded by implementations of Checker to obtain +// no-op implementations of Checker methods that may be explicitly overridden. +type CheckerDefaults struct{} + +// Clone implements Checker.Clone. +func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + return nil +} + +// Execve implements Checker.Execve. +func (CheckerDefaults) Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error { + return nil +} + +// ExitNotifyParent implements Checker.ExitNotifyParent. +func (CheckerDefaults) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error { + return nil +} + +// CheckerReq indicates what checkpoints a corresponding Checker runs at, and +// what information it requires at those checkpoints. +type CheckerReq struct { + // Points are the set of checkpoints for which the corresponding Checker + // must be called. Note that methods not specified in Points may still be + // called; implementations of Checker may embed CheckerDefaults to obtain + // no-op implementations of Checker methods. + Points []Point + + // All of the following fields indicate what fields in the corresponding + // XInfo struct will be requested at the corresponding checkpoint. + Clone CloneFields + Execve ExecveFields + ExitNotifyParent ExitNotifyParentFields +} + +// Global is the method receiver of all seccheck functions. +var Global state + +// state is the type of global, and is separated out for testing. +type state struct { + // registrationMu serializes all changes to the set of registered Checkers + // for all checkpoints. + registrationMu sync.Mutex + + // enabledPoints is a bitmask of checkpoints for which at least one Checker + // is registered. + // + // enabledPoints is accessed using atomic memory operations. Mutation of + // enabledPoints is serialized by registrationMu. + enabledPoints [numPointBitmaskUint32s]uint32 + + // registrationSeq supports store-free atomic reads of registeredCheckers. + registrationSeq sync.SeqCount + + // checkers is the set of all registered Checkers in order of execution. + // + // checkers is accessed using instantiations of SeqAtomic functions. + // Mutation of checkers is serialized by registrationMu. + checkers []Checker + + // All of the following xReq variables indicate what fields in the + // corresponding XInfo struct have been requested by any registered + // checker, are accessed using atomic memory operations, and are mutated + // with registrationMu locked. + cloneReq CloneFieldSet + execveReq ExecveFieldSet + exitNotifyParentReq ExitNotifyParentFieldSet +} + +// AppendChecker registers the given Checker to execute at checkpoints. The +// Checker will execute after all previously-registered Checkers, and only if +// those Checkers return a nil error. +func (s *state) AppendChecker(c Checker, req *CheckerReq) { + s.registrationMu.Lock() + defer s.registrationMu.Unlock() + + s.cloneReq.AddFieldsLoadable(req.Clone) + s.execveReq.AddFieldsLoadable(req.Execve) + s.exitNotifyParentReq.AddFieldsLoadable(req.ExitNotifyParent) + + s.appendCheckerLocked(c) + for _, p := range req.Points { + word, bit := p/32, p%32 + atomic.StoreUint32(&s.enabledPoints[word], s.enabledPoints[word]|(uint32(1)<<bit)) + } +} + +// Enabled returns true if any Checker is registered for the given checkpoint. +func (s *state) Enabled(p Point) bool { + word, bit := p/32, p%32 + return atomic.LoadUint32(&s.enabledPoints[word])&(uint32(1)<<bit) != 0 +} + +func (s *state) getCheckers() []Checker { + return SeqAtomicLoadCheckerSlice(&s.registrationSeq, &s.checkers) +} + +// Preconditions: s.registrationMu must be locked. +func (s *state) appendCheckerLocked(c Checker) { + s.registrationSeq.BeginWrite() + s.checkers = append(s.checkers, c) + s.registrationSeq.EndWrite() +} diff --git a/pkg/sentry/seccheck/seccheck_test.go b/pkg/sentry/seccheck/seccheck_test.go new file mode 100644 index 000000000..687810d18 --- /dev/null +++ b/pkg/sentry/seccheck/seccheck_test.go @@ -0,0 +1,157 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "errors" + "testing" + + "gvisor.dev/gvisor/pkg/context" +) + +type testChecker struct { + CheckerDefaults + + onClone func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error +} + +// Clone implements Checker.Clone. +func (c *testChecker) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + if c.onClone == nil { + return nil + } + return c.onClone(ctx, mask, info) +} + +func TestNoChecker(t *testing.T) { + var s state + if s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got true, wanted false") + } +} + +func TestCheckerNotRegisteredForPoint(t *testing.T) { + var s state + s.AppendChecker(&testChecker{}, &CheckerReq{}) + if s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got true, wanted false") + } +} + +func TestCheckerRegistered(t *testing.T) { + var s state + checkerCalled := false + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkerCalled = true + return nil + }}, &CheckerReq{ + Points: []Point{PointClone}, + Clone: CloneFields{ + Credentials: true, + }, + }) + + if !s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got false, wanted true") + } + if !s.CloneReq().Contains(CloneFieldCredentials) { + t.Errorf("CloneReq().Contains(CloneFieldCredentials): got false, wanted true") + } + if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil { + t.Errorf("Clone(): got %v, wanted nil", err) + } + if !checkerCalled { + t.Errorf("Clone() did not call Checker.Clone()") + } +} + +func TestMultipleCheckersRegistered(t *testing.T) { + var s state + checkersCalled := [2]bool{} + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkersCalled[0] = true + return nil + }}, &CheckerReq{ + Points: []Point{PointClone}, + Clone: CloneFields{ + Args: true, + }, + }) + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkersCalled[1] = true + return nil + }}, &CheckerReq{ + Points: []Point{PointClone}, + Clone: CloneFields{ + Created: TaskFields{ + ThreadID: true, + }, + }, + }) + + if !s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got false, wanted true") + } + // CloneReq() should return the union of requested fields from all calls to + // AppendChecker. + req := s.CloneReq() + if !req.Contains(CloneFieldArgs) { + t.Errorf("req.Contains(CloneFieldArgs): got false, wanted true") + } + if !req.Created.Contains(TaskFieldThreadID) { + t.Errorf("req.Created.Contains(TaskFieldThreadID): got false, wanted true") + } + if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil { + t.Errorf("Clone(): got %v, wanted nil", err) + } + for i := range checkersCalled { + if !checkersCalled[i] { + t.Errorf("Clone() did not call Checker.Clone() index %d", i) + } + } +} + +func TestCheckpointReturnsFirstCheckerError(t *testing.T) { + errFirstChecker := errors.New("first Checker error") + errSecondChecker := errors.New("second Checker error") + + var s state + checkersCalled := [2]bool{} + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkersCalled[0] = true + return errFirstChecker + }}, &CheckerReq{ + Points: []Point{PointClone}, + }) + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkersCalled[1] = true + return errSecondChecker + }}, &CheckerReq{ + Points: []Point{PointClone}, + }) + + if !s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got false, wanted true") + } + if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != errFirstChecker { + t.Errorf("Clone(): got %v, wanted %v", err, errFirstChecker) + } + if !checkersCalled[0] { + t.Errorf("Clone() did not call first Checker") + } + if checkersCalled[1] { + t.Errorf("Clone() called second Checker") + } +} diff --git a/pkg/sentry/seccheck/task.go b/pkg/sentry/seccheck/task.go new file mode 100644 index 000000000..1dee33203 --- /dev/null +++ b/pkg/sentry/seccheck/task.go @@ -0,0 +1,39 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" +) + +// TaskInfo contains information unambiguously identifying a single thread +// and/or its containing process. +// +// +fieldenum Task +type TaskInfo struct { + // ThreadID is the thread's ID in the root PID namespace. + ThreadID int32 + + // ThreadStartTime is the thread's CLOCK_REALTIME start time. + ThreadStartTime ktime.Time + + // ThreadGroupID is the thread's group leader's ID in the root PID + // namespace. + ThreadGroupID int32 + + // ThreadGroupStartTime is the thread's group leader's CLOCK_REALTIME start + // time. + ThreadGroupStartTime ktime.Time +} diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 7ee89a735..00f925166 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "socket", - srcs = ["socket.go"], + srcs = [ + "socket.go", + "socket_state.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 00a5e729a..6077b2150 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -29,10 +29,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "time" ) -const maxInt = int(^uint(0) >> 1) - // SCMCredentials represents a SCM_CREDENTIALS socket control message. type SCMCredentials interface { transport.CredentialsControlMessage @@ -78,7 +77,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { } // Files implements SCMRights.Files. -func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) { +func (fs *RightsFiles) Files(_ context.Context, max int) (RightsFiles, bool) { n := max var trunc bool if l := len(*fs); n > l { @@ -124,7 +123,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32 break } - fds = append(fds, int32(fd)) + fds = append(fds, fd) } return fds, trunc } @@ -300,8 +299,8 @@ func alignSlice(buf []byte, align uint) []byte { } // PackTimestamp packs a SO_TIMESTAMP socket control message. -func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { - timestampP := linux.NsecToTimeval(timestamp) +func PackTimestamp(t *kernel.Task, timestamp time.Time, buf []byte) []byte { + timestampP := linux.NsecToTimeval(timestamp.UnixNano()) return putCmsgStruct( buf, linux.SOL_SOCKET, @@ -355,6 +354,17 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketIn ) } +// PackIPv6PacketInfo packs an IPV6_PKTINFO socket control message. +func PackIPv6PacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPv6PacketInfo, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IPV6, + linux.IPV6_PKTINFO, + t.Arch().Width(), + packetInfo, + ) +} + // PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { var level uint32 @@ -412,6 +422,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) } + if cmsgs.IP.HasIPv6PacketInfo { + buf = PackIPv6PacketInfo(t, &cmsgs.IP.IPv6PacketInfo, buf) + } + if cmsgs.IP.OriginalDstAddress != nil { buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) } @@ -453,6 +467,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) } + if cmsgs.IP.HasIPv6PacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPv6PacketInfo) + } + if cmsgs.IP.OriginalDstAddress != nil { space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) } @@ -526,7 +544,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var ts linux.Timeval ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) - cmsgs.IP.Timestamp = ts.ToNsecCapped() + cmsgs.IP.Timestamp = ts.ToTime() cmsgs.IP.HasTimestamp = true i += bits.AlignUp(length, width) diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go index 7e28a0cef..1b04e1bbc 100644 --- a/pkg/sentry/socket/control/control_test.go +++ b/pkg/sentry/socket/control/control_test.go @@ -50,7 +50,7 @@ func TestParse(t *testing.T) { want := socket.ControlMessages{ IP: socket.IPControlMessages{ HasTimestamp: true, - Timestamp: ts.ToNsecCapped(), + Timestamp: ts.ToTime(), }, } if diff := cmp.Diff(want, cmsg); diff != "" { diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 3950caa0f..4ea89f9d0 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -38,7 +38,6 @@ go_library( "//pkg/sentry/socket/control", "//pkg/sentry/vfs", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/stack", "//pkg/usermem", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 38cb2c99c..6e2318f75 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -35,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -112,7 +111,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } return readv(s.fd, safemem.IovecsFromBlockSeq(dsts)) })) - return int64(n), err + return n, err } // Write implements fs.FileOperations.Write. @@ -135,7 +134,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } return writev(s.fd, safemem.IovecsFromBlockSeq(srcs)) })) - return int64(n), err + return n, err } // Socket implements socket.Provider.Socket. @@ -181,7 +180,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } @@ -208,7 +207,7 @@ type socketOpsCommon struct { // Release implements fs.FileOperations.Release. func (s *socketOpsCommon) Release(context.Context) { fdnotifier.RemoveFD(int32(s.fd)) - unix.Close(s.fd) + _ = unix.Close(s.fd) } // Readiness implements waiter.Waitable.Readiness. @@ -219,13 +218,13 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { // EventRegister implements waiter.Waitable.EventRegister. func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(s.fd)) + _ = fdnotifier.UpdateFD(int32(s.fd)) } // EventUnregister implements waiter.Waitable.EventUnregister. func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(s.fd)) + _ = fdnotifier.UpdateFD(int32(s.fd)) } // Connect implements socket.Socket.Connect. @@ -288,7 +287,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC) if blocking { var ch chan struct{} - for syscallErr == syserror.ErrWouldBlock { + for syscallErr == linuxerr.ErrWouldBlock { if ch != nil { if syscallErr = t.Block(ch); syscallErr != nil { break @@ -317,7 +316,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, if kernel.VFS2Enabled { f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&unix.SOCK_NONBLOCK)) if err != nil { - unix.Close(fd) + _ = unix.Close(fd) return 0, nil, 0, err } defer f.DecRef(t) @@ -329,7 +328,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, } else { f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&unix.SOCK_NONBLOCK != 0) if err != nil { - unix.Close(fd) + _ = unix.Close(fd) return 0, nil, 0, err } defer f.DecRef(t) @@ -344,7 +343,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, } // Bind implements socket.Socket.Bind. -func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -357,12 +356,12 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Listen implements socket.Socket.Listen. -func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error { return syserr.FromError(unix.Listen(s.fd, backlog)) } // Shutdown implements socket.Socket.Shutdown. -func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error { switch how { case unix.SHUT_RD, unix.SHUT_WR, unix.SHUT_RDWR: return syserr.FromError(unix.Shutdown(s.fd, how)) @@ -372,7 +371,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, _ hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -402,7 +401,7 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr case linux.TCP_NODELAY: optlen = sizeofInt32 case linux.TCP_INFO: - optlen = int(linux.SizeOfTCPInfo) + optlen = linux.SizeOfTCPInfo } } @@ -535,7 +534,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags n, err := copyToDst() // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. if flags&(unix.MSG_DONTWAIT|unix.MSG_ERRQUEUE) == 0 { - for err == syserror.ErrWouldBlock { + for err == linuxerr.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. if n != 0 { @@ -580,7 +579,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) - controlMessages.IP.Timestamp = ts.ToNsecCapped() + controlMessages.IP.Timestamp = ts.ToTime() } case linux.SOL_IP: @@ -707,7 +706,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var ch chan struct{} n, err := src.CopyInTo(t, sendmsgFromBlocks) if flags&unix.MSG_DONTWAIT == 0 { - for err == syserror.ErrWouldBlock { + for err == linuxerr.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. if n != 0 { @@ -716,7 +715,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if ch != nil { if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -735,7 +734,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b func translateIOSyscallError(err error) error { if err == unix.EAGAIN || err == unix.EWOULDBLOCK { - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } return err } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index ea56f39c1..b9c15daab 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID { } // Action implements stack.Target.Action. -func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index ed85404da..9710a15ee 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -36,7 +36,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 5c3ae26f8..ed5fa9c38 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -39,7 +39,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -530,7 +529,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } - if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if n, err := doRead(); err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC @@ -548,7 +547,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := doRead(); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != linuxerr.ErrWouldBlock { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e828982eb..075f61cda 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "device.go", "netstack.go", + "netstack_state.go", "netstack_vfs2.go", "provider.go", "provider_vfs2.go", @@ -42,13 +43,13 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/header", "//pkg/tcpip/link/tun", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/usermem", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9b844b0c0..030c6c8e4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -56,12 +56,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -275,6 +274,7 @@ var Metrics = tcpip.Stats{ ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."), SegmentsAckedWithDSACK: mustCreateMetric("/netstack/tcp/segments_acked_with_dsack", "Number of segments for which DSACK was received."), + SpuriousRecovery: mustCreateMetric("/netstack/tcp/spurious_recovery", "Number of times the connection entered loss recovery spuriously."), }, UDP: tcpip.UDPStats{ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."), @@ -379,9 +379,9 @@ type socketOpsCommon struct { // timestampValid indicates whether timestamp for SIOCGSTAMP has been // set. It is protected by readMu. timestampValid bool - // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only + // timestamp holds the timestamp to use with SIOCTSTAMP. It is only // valid when timestampValid is true. It is protected by readMu. - timestampNS int64 + timestamp time.Time `state:".(int64)"` // TODO(b/153685824): Move this to SocketOptions. // sockOptInq corresponds to TCP_INQ. @@ -411,13 +411,25 @@ var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes() var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes() var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes() -// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the -// netstack representation taking any addresses into account. -func bytesToIPAddress(addr []byte) tcpip.Address { - if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { - return "" +// minSockAddrLen returns the minimum length in bytes of a socket address for +// the socket's family. +func (s *socketOpsCommon) minSockAddrLen() int { + const addressFamilySize = 2 + + switch s.family { + case linux.AF_UNIX: + return addressFamilySize + case linux.AF_INET: + return sockAddrInetSize + case linux.AF_INET6: + return sockAddrInet6Size + case linux.AF_PACKET: + return sockAddrLinkSize + case linux.AF_UNSPEC: + return addressFamilySize + default: + panic(fmt.Sprintf("s.family unrecognized = %d", s.family)) } - return tcpip.Address(addr) } func (s *socketOpsCommon) isPacketBased() bool { @@ -448,7 +460,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { t := kernel.TaskFromContext(ctx) start := t.Kernel().MonotonicClock().Now() deadline := start.Add(v.Timeout) - t.BlockWithDeadline(ch, true, deadline) + _ = t.BlockWithDeadline(ch, true, deadline) } } @@ -459,7 +471,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) if err == syserr.ErrWouldBlock { - return int64(n), syserror.ErrWouldBlock + return int64(n), linuxerr.ErrWouldBlock } if err != nil { return 0, err.ToError() @@ -468,7 +480,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } // WriteTo implements fs.FileOperations.WriteTo. -func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { +func (s *SocketOperations) WriteTo(_ context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { s.readMu.Lock() defer s.readMu.Unlock() @@ -492,14 +504,14 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO r := src.Reader(ctx) n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if _, ok := err.(*tcpip.ErrWouldBlock); ok { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } if n < src.NumBytes() { - return n, syserror.ErrWouldBlock + return n, linuxerr.ErrWouldBlock } return n, nil @@ -523,7 +535,7 @@ func (l *limitedPayloader) Len() int { } // ReadFrom implements fs.FileOperations.ReadFrom. -func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { +func (s *SocketOperations) ReadFrom(_ context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { f := limitedPayloader{ inner: io.LimitedReader{ R: r, @@ -546,16 +558,21 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return s.Endpoint.Readiness(mask) } -func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { +// checkFamily returns true iff the specified address family may be used with +// the socket. +// +// If exact is true, then the specified address family must be an exact match +// with the socket's family. +func (s *socketOpsCommon) checkFamily(family uint16, exact bool) bool { if family == uint16(s.family) { - return nil + return true } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { if !s.Endpoint.SocketOptions().GetV6Only() { - return nil + return true } } - return syserr.ErrInvalidArgument + return false } // mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the @@ -588,8 +605,8 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool return syserr.TranslateNetstackError(err) } - if err := s.checkFamily(family, false /* exact */); err != nil { - return err + if !s.checkFamily(family, false /* exact */) { + return syserr.ErrInvalidArgument } addr = s.mapFamily(addr, family) @@ -629,7 +646,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < 2 { return syserr.ErrInvalidArgument } @@ -647,23 +664,24 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) - if a.Protocol != uint16(s.protocol) { - return syserr.ErrInvalidArgument - } - addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: socket.Ntohs(a.Protocol), } } else { + if s.minSockAddrLen() > len(sockaddr) { + return syserr.ErrInvalidArgument + } + var err *syserr.Error addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } - if err = s.checkFamily(family, true /* exact */); err != nil { - return err + if !s.checkFamily(family, true /* exact */) { + return syserr.ErrAddressFamilyNotSupported } addr = s.mapFamily(addr, family) @@ -688,7 +706,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Listen implements the linux syscall listen(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error { return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog)) } @@ -779,7 +797,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) { // Shutdown implements the linux syscall shutdown(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error { f, err := ConvertShutdown(how) if err != nil { return err @@ -860,7 +878,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, _ linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1345,6 +1363,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) return &v, nil + case linux.IPV6_RECVPKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6ReceivePacketInfo())) + return &v, nil + case linux.IP6T_ORIGINAL_DST: if outLen < sockAddrInet6Size { return nil, syserr.ErrInvalidArgument @@ -1368,11 +1394,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true) + info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, true) if err != nil { return nil, err } @@ -1388,11 +1414,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen) + entries, err := netfilter.GetEntries6(t, stk.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } @@ -1408,8 +1434,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber) @@ -1425,7 +1451,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { +func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, _ int) (marshal.Marshallable, *syserr.Error) { if _, ok := ep.(tcpip.Endpoint); !ok { log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) return nil, syserr.ErrUnknownProtocolOption @@ -1565,11 +1591,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false) + info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, false) if err != nil { return nil, err } @@ -1585,11 +1611,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen) + entries, err := netfilter.GetEntries4(t, stk.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } @@ -1605,8 +1631,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber) @@ -2046,7 +2072,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { return syserr.ErrInvalidEndpointState - } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + } else if isUDPSocket(skType, skProto) && transport.DatagramEndpointState(ep.State()) != transport.DatagramEndpointStateInitial { return syserr.ErrInvalidEndpointState } @@ -2101,6 +2127,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) return nil + case linux.IPV6_RECVPKTINFO: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(hostarch.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetIPv6ReceivePacketInfo(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2143,12 +2178,12 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true) + return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, true) case linux.IP6T_SO_SET_ADD_COUNTERS: log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported") @@ -2386,12 +2421,12 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false) + return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, false) case linux.IPT_SO_SET_ADD_COUNTERS: log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported") @@ -2490,7 +2525,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, linux.IPV6_RECVPATHMTU, - linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, linux.IPV6_RTHDR, linux.IPV6_RTHDRDSTOPTS, @@ -2559,7 +2593,7 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2571,7 +2605,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2716,6 +2750,8 @@ func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.Contr TClass: readCM.TClass, HasIPPacketInfo: readCM.HasIPPacketInfo, PacketInfo: readCM.PacketInfo, + HasIPv6PacketInfo: readCM.HasIPv6PacketInfo, + IPv6PacketInfo: readCM.IPv6PacketInfo, OriginalDstAddress: readCM.OriginalDstAddress, SockErr: readCM.SockErr, }, @@ -2730,7 +2766,7 @@ func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) { // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. if !s.sockOptTimestamp { s.timestampValid = true - s.timestampNS = cm.Timestamp + s.timestamp = cm.Timestamp } } @@ -2789,7 +2825,7 @@ func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, _ uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { if flags&linux.MSG_ERRQUEUE != 0 { return s.recvErr(t, dst) } @@ -2873,8 +2909,8 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if err != nil { return 0, err } - if err := s.checkFamily(family, false /* exact */); err != nil { - return 0, err + if !s.checkFamily(family, false /* exact */) { + return 0, syserr.ErrInvalidArgument } addrBuf = s.mapFamily(addrBuf, family) @@ -2951,10 +2987,10 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { - return 0, syserror.ENOENT + return 0, linuxerr.ENOENT } - tv := linux.NsecToTimeval(s.timestampNS) + tv := linux.NsecToTimeval(s.timestamp.UnixNano()) _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err @@ -3061,7 +3097,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc } // interfaceIoctl implements interface requests. -func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error { +func interfaceIoctl(ctx context.Context, _ usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error { var ( iface inet.Interface index int32 @@ -3069,8 +3105,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe ) // Find the relevant device. - stack := inet.StackFromContext(ctx) - if stack == nil { + stk := inet.StackFromContext(ctx) + if stk == nil { return syserr.ErrNoDevice } @@ -3080,7 +3116,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4])) - if iface, ok := stack.Interfaces()[index]; ok { + if iface, ok := stk.Interfaces()[index]; ok { ifr.SetName(iface.Name) return nil } @@ -3088,7 +3124,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // Find the relevant device. - for index, iface = range stack.Interfaces() { + for index, iface = range stk.Interfaces() { if iface.Name == ifr.Name() { found = true break @@ -3121,7 +3157,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } case linux.SIOCGIFFLAGS: - f, err := interfaceStatusFlags(stack, iface.Name) + f, err := interfaceStatusFlags(stk, iface.Name) if err != nil { return err } @@ -3131,7 +3167,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case linux.SIOCGIFADDR: // Copy the IPv4 address out. - for _, addr := range stack.InterfaceAddrs()[index] { + for _, addr := range stk.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. if addr.Family != linux.AF_INET { continue @@ -3167,7 +3203,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case linux.SIOCGIFNETMASK: // Gets the network mask of a device. - for _, addr := range stack.InterfaceAddrs()[index] { + for _, addr := range stk.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. if addr.Family != linux.AF_INET { continue @@ -3199,24 +3235,24 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, _ usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. - stack := inet.StackFromContext(ctx) - if stack == nil { + stk := inet.StackFromContext(ctx) + if stk == nil { return syserr.ErrNoDevice.ToError() } if ifc.Ptr == 0 { - ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq) + ifc.Len = int32(len(stk.Interfaces())) * int32(linux.SizeOfIFReq) return nil } max := ifc.Len ifc.Len = 0 - for key, ifaceAddrs := range stack.InterfaceAddrs() { - iface := stack.Interfaces()[key] + for key, ifaceAddrs := range stk.InterfaceAddrs() { + iface := stk.Interfaces()[key] for _, ifaceAddr := range ifaceAddrs { // Don't write past the end of the buffer. if ifc.Len+int32(linux.SizeOfIFReq) > max { @@ -3332,10 +3368,10 @@ func (s *socketOpsCommon) State() uint32 { } case isUDPSocket(s.skType, s.protocol): // UDP socket. - switch udp.EndpointState(s.Endpoint.State()) { - case udp.StateInitial, udp.StateBound, udp.StateClosed: + switch transport.DatagramEndpointState(s.Endpoint.State()) { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateBound, transport.DatagramEndpointStateClosed: return linux.TCP_CLOSE - case udp.StateConnected: + case transport.DatagramEndpointStateConnected: return linux.TCP_ESTABLISHED default: return 0 diff --git a/pkg/sentry/socket/netstack/netstack_state.go b/pkg/sentry/socket/netstack/netstack_state.go new file mode 100644 index 000000000..591e00d42 --- /dev/null +++ b/pkg/sentry/socket/netstack/netstack_state.go @@ -0,0 +1,31 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netstack + +import ( + "time" +) + +func (s *socketOpsCommon) saveTimestamp() int64 { + s.readMu.Lock() + defer s.readMu.Unlock() + return s.timestamp.UnixNano() +} + +func (s *socketOpsCommon) loadTimestamp(nsec int64) { + s.readMu.Lock() + defer s.readMu.Unlock() + s.timestamp = time.Unix(0, nsec) +} diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index edc160b1b..3cdf29b80 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -113,7 +112,7 @@ func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs. } n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) if err == syserr.ErrWouldBlock { - return int64(n), syserror.ErrWouldBlock + return int64(n), linuxerr.ErrWouldBlock } if err != nil { return 0, err.ToError() @@ -132,14 +131,14 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs r := src.Reader(ctx) n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if _, ok := err.(*tcpip.ErrWouldBlock); ok { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } if n < src.NumBytes() { - return n, syserror.ErrWouldBlock + return n, linuxerr.ErrWouldBlock } return n, nil diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 208ab9909..ea199f223 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // Attach address to interface. nicID := tcpip.NICID(idx) - if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + if err := s.Stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { return syserr.TranslateNetstackError(err).ToError() } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 658e90bb9..d4b80a39d 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "sync/atomic" + "time" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" @@ -51,8 +52,19 @@ type ControlMessages struct { func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { var p linux.ControlMessageIPPacketInfo p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + copy(p.LocalAddr[:], packetInfo.LocalAddr) + copy(p.DestinationAddr[:], packetInfo.DestinationAddr) + return p +} + +// ipv6PacketInfoToLinux converts IPv6PacketInfo from tcpip format to Linux +// format. +func ipv6PacketInfoToLinux(packetInfo tcpip.IPv6PacketInfo) linux.ControlMessageIPv6PacketInfo { + var p linux.ControlMessageIPv6PacketInfo + if n := copy(p.Addr[:], packetInfo.Addr); n != len(p.Addr) { + panic(fmt.Sprintf("got copy(%x, %x) = %d, want = %d", p.Addr, packetInfo.Addr, n, len(p.Addr))) + } + p.NIC = uint32(packetInfo.NIC) return p } @@ -114,7 +126,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa if cmgs.HasOriginalDstAddress { orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) } - return IPControlMessages{ + cm := IPControlMessages{ HasTimestamp: cmgs.HasTimestamp, Timestamp: cmgs.Timestamp, HasInq: cmgs.HasInq, @@ -125,9 +137,16 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa TClass: cmgs.TClass, HasIPPacketInfo: cmgs.HasIPPacketInfo, PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + HasIPv6PacketInfo: cmgs.HasIPv6PacketInfo, OriginalDstAddress: orgDstAddr, SockErr: sockErrCmsgToLinux(cmgs.SockErr), } + + if cm.HasIPv6PacketInfo { + cm.IPv6PacketInfo = ipv6PacketInfoToLinux(cmgs.IPv6PacketInfo) + } + + return cm } // IPControlMessages contains socket control messages for IP sockets. @@ -138,9 +157,9 @@ type IPControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packet used to create - // the read data was received. - Timestamp int64 + // Timestamp is the time that the last packet used to create the read data + // was received. + Timestamp time.Time `state:".(int64)"` // HasInq indicates whether Inq is valid/set. HasInq bool @@ -166,6 +185,12 @@ type IPControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo linux.ControlMessageIPPacketInfo + // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set. + HasIPv6PacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + IPv6PacketInfo linux.ControlMessageIPv6PacketInfo + // OriginalDestinationAddress holds the original destination address // and port of the incoming packet. OriginalDstAddress linux.SockAddr @@ -743,6 +768,8 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) + // TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have + // an ethernet address. if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } @@ -750,6 +777,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: Ntohs(a.Protocol), }, family, nil case linux.AF_UNSPEC: diff --git a/pkg/sentry/socket/socket_state.go b/pkg/sentry/socket/socket_state.go new file mode 100644 index 000000000..32e12b238 --- /dev/null +++ b/pkg/sentry/socket/socket_state.go @@ -0,0 +1,27 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package socket + +import ( + "time" +) + +func (i *IPControlMessages) saveTimestamp() int64 { + return i.Timestamp.UnixNano() +} + +func (i *IPControlMessages) loadTimestamp(nsec int64) { + i.Timestamp = time.Unix(0, nsec) +} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 5c3cdef6a..7b546c04d 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -62,7 +62,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 33f9aeb06..b3f0cf563 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -129,9 +129,9 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv stype: stype, } + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } @@ -406,14 +406,15 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { // Accept accepts a new connection. func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) { e.Lock() - defer e.Unlock() if !e.Listening() { + e.Unlock() return nil, syserr.ErrInvalidEndpointState } select { case ne := <-e.acceptedChan: + e.Unlock() if peerAddr != nil { ne.Lock() c := ne.connected @@ -429,6 +430,7 @@ func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *s return ne, nil default: + e.Unlock() // Nothing left. return nil, syserr.ErrWouldBlock } @@ -517,3 +519,6 @@ func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { } return v } + +// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters. +func (e *connectionedEndpoint) WakeupWriters() {} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 61338728a..61311718e 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,9 +44,9 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: defaultBufferSize} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } @@ -227,3 +227,6 @@ func (e *connectionlessEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { } return v } + +// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters. +func (e *connectionlessEndpoint) WakeupWriters() {} diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e4de44498..188ad3bd9 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -59,12 +59,14 @@ func (q *queue) Close() { // q.WriterQueue.Notify(waiter.WritableEvents) func (q *queue) Reset(ctx context.Context) { q.mu.Lock() - for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { - cur.Release(ctx) - } + dataList := q.dataList q.dataList.Reset() q.used = 0 q.mu.Unlock() + + for cur := dataList.Front(); cur != nil; cur = cur.Next() { + cur.Release(ctx) + } } // DecRef implements RefCounter.DecRef. @@ -133,7 +135,7 @@ func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, f free := q.limit - q.used if l > free && truncate { - if free == 0 { + if free <= 0 { // Message can't fit right now. q.mu.Unlock() return 0, false, syserr.ErrWouldBlock diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 8ccdadae9..e9e482017 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -38,7 +38,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -494,7 +493,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } n, err := src.CopyInTo(t, &w) - if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { return int(n), syserr.FromError(err) } @@ -514,13 +513,13 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b n, err = src.CopyInTo(t, &w) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -648,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } var total int64 - if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait { + if n, err := doRead(); err != linuxerr.ErrWouldBlock || dontWait { var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { @@ -683,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := doRead(); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != linuxerr.ErrWouldBlock { var from linux.SockAddr var fromLen uint32 if r.From != nil { diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 757ff2a40..4d3f4d556 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -610,9 +610,9 @@ func (i *SyscallInfo) printExit(t *kernel.Task, elapsed time.Duration, output [] if err == nil { // Fill in the output after successful execution. i.post(t, args, retval, output, LogMaximumSize) - rval = fmt.Sprintf("%#x (%v)", retval, elapsed) + rval = fmt.Sprintf("%d (%#x) (%v)", retval, retval, elapsed) } else { - rval = fmt.Sprintf("%#x errno=%d (%s) (%v)", retval, errno, err, elapsed) + rval = fmt.Sprintf("%d (%#x) errno=%d (%s) (%v)", retval, retval, errno, err, elapsed) } switch len(output) { diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD index f2c55588f..7a7c80ac6 100644 --- a/pkg/sentry/syscalls/BUILD +++ b/pkg/sentry/syscalls/BUILD @@ -16,7 +16,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/epoll", "//pkg/sentry/kernel/time", - "//pkg/syserror", "//pkg/waiter", ], ) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index b5a371d9a..394396cde 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -104,7 +104,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 76389fbe3..f4d549a3f 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) var ( @@ -90,9 +89,9 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er } // Translate error, if possible, to consolidate errors from other packages - // into a smaller set of errors from syserror package. + // into a smaller set of errors from linuxerr package. translatedErr := errOrig - if errno, ok := syserror.TranslateError(errOrig); ok { + if errno, ok := linuxerr.TranslateError(errOrig); ok { translatedErr = errno } switch { @@ -167,10 +166,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er // files. Since we have a partial read/write, we consume // ErrWouldBlock, returning the partial result. return true, nil - } - - switch errOrig.(type) { - case syserror.SyscallRestartErrno: + case linuxerr.IsRestartError(translatedErr): // Identical to the EINTR case. return true, nil } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 1ead3c7e8..2046a48b9 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/syscalls" - "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -124,7 +123,7 @@ var AMD64 = &kernel.SyscallTable{ 68: syscalls.Supported("msgget", Msgget), 69: syscalls.Supported("msgsnd", Msgsnd), 70: syscalls.Supported("msgrcv", Msgrcv), - 71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}), + 71: syscalls.Supported("msgctl", Msgctl), 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), @@ -175,8 +174,8 @@ var AMD64 = &kernel.SyscallTable{ 119: syscalls.Supported("setresgid", Setresgid), 120: syscalls.Supported("getresgid", Getresgid), 121: syscalls.Supported("getpgid", Getpgid), - 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) - 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 122: syscalls.ErrorWithEvent("setfsuid", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 123: syscalls.ErrorWithEvent("setfsgid", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) 124: syscalls.Supported("getsid", Getsid), 125: syscalls.Supported("capget", Capget), 126: syscalls.Supported("capset", Capset), @@ -187,12 +186,12 @@ var AMD64 = &kernel.SyscallTable{ 131: syscalls.Supported("sigaltstack", Sigaltstack), 132: syscalls.Supported("utime", Utime), 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil), - 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil), + 134: syscalls.Error("uselib", linuxerr.ENOSYS, "Obsolete", nil), 135: syscalls.ErrorWithEvent("personality", linuxerr.EINVAL, "Unable to change personality.", nil), - 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil), + 136: syscalls.ErrorWithEvent("ustat", linuxerr.ENOSYS, "Needs filesystem support.", nil), 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), - 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}), + 139: syscalls.ErrorWithEvent("sysfs", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/165"}), 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), @@ -230,15 +229,15 @@ var AMD64 = &kernel.SyscallTable{ 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil), 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), - 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), - 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), + 177: syscalls.Error("get_kernel_syms", linuxerr.ENOSYS, "Not supported in Linux > 2.6.", nil), + 178: syscalls.Error("query_module", linuxerr.ENOSYS, "Not supported in Linux > 2.6.", nil), 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations - 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), - 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), - 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), - 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil), - 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), - 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), + 180: syscalls.Error("nfsservctl", linuxerr.ENOSYS, "Removed after Linux 3.1.", nil), + 181: syscalls.Error("getpmsg", linuxerr.ENOSYS, "Not implemented in Linux.", nil), + 182: syscalls.Error("putpmsg", linuxerr.ENOSYS, "Not implemented in Linux.", nil), + 183: syscalls.Error("afs_syscall", linuxerr.ENOSYS, "Not implemented in Linux.", nil), + 184: syscalls.Error("tuxcall", linuxerr.ENOSYS, "Not implemented in Linux.", nil), + 185: syscalls.Error("security", linuxerr.ENOSYS, "Not implemented in Linux.", nil), 186: syscalls.Supported("gettid", Gettid), 187: syscalls.Supported("readahead", Readahead), 188: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil), @@ -258,18 +257,18 @@ var AMD64 = &kernel.SyscallTable{ 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), - 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), + 205: syscalls.Error("set_thread_area", linuxerr.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), + 211: syscalls.Error("get_thread_area", linuxerr.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), 213: syscalls.Supported("epoll_create", EpollCreate), - 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil), - 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil), - 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), + 214: syscalls.ErrorWithEvent("epoll_ctl_old", linuxerr.ENOSYS, "Deprecated.", nil), + 215: syscalls.ErrorWithEvent("epoll_wait_old", linuxerr.ENOSYS, "Deprecated.", nil), + 216: syscalls.ErrorWithEvent("remap_file_pages", linuxerr.ENOSYS, "Deprecated since Linux 3.16.", nil), 217: syscalls.Supported("getdents64", Getdents64), 218: syscalls.Supported("set_tid_address", SetTidAddress), 219: syscalls.Supported("restart_syscall", RestartSyscall), @@ -289,16 +288,16 @@ var AMD64 = &kernel.SyscallTable{ 233: syscalls.Supported("epoll_ctl", EpollCtl), 234: syscalls.Supported("tgkill", Tgkill), 235: syscalls.Supported("utimes", Utimes), - 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil), + 236: syscalls.Error("vserver", linuxerr.ENOSYS, "Not implemented by Linux", nil), 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), - 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 240: syscalls.ErrorWithEvent("mq_open", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 241: syscalls.ErrorWithEvent("mq_unlink", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 242: syscalls.ErrorWithEvent("mq_timedsend", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 243: syscalls.ErrorWithEvent("mq_timedreceive", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 244: syscalls.ErrorWithEvent("mq_notify", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 245: syscalls.ErrorWithEvent("mq_getsetattr", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), 247: syscalls.Supported("waitid", Waitid), 248: syscalls.Error("add_key", linuxerr.EACCES, "Not available to user.", nil), @@ -331,7 +330,7 @@ var AMD64 = &kernel.SyscallTable{ 275: syscalls.Supported("splice", Splice), 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), - 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 278: syscalls.ErrorWithEvent("vmsplice", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) 280: syscalls.Supported("utimensat", Utimensat), 281: syscalls.Supported("epoll_pwait", EpollPwait), @@ -353,8 +352,8 @@ var AMD64 = &kernel.SyscallTable{ 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), 298: syscalls.ErrorWithEvent("perf_event_open", linuxerr.ENODEV, "No support for perf counters", nil), 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), - 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), - 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 300: syscalls.ErrorWithEvent("fanotify_init", linuxerr.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 301: syscalls.ErrorWithEvent("fanotify_mark", linuxerr.ENOSYS, "Needs CONFIG_FANOTIFY", nil), 302: syscalls.Supported("prlimit64", Prlimit64), 303: syscalls.Error("name_to_handle_at", linuxerr.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), 304: syscalls.Error("open_by_handle_at", linuxerr.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), @@ -363,48 +362,48 @@ var AMD64 = &kernel.SyscallTable{ 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), 308: syscalls.ErrorWithEvent("setns", linuxerr.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) 309: syscalls.Supported("getcpu", Getcpu), - 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), - 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 310: syscalls.ErrorWithEvent("process_vm_readv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 311: syscalls.ErrorWithEvent("process_vm_writev", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/158"}), 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), - 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) + 314: syscalls.ErrorWithEvent("sched_setattr", linuxerr.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 315: syscalls.ErrorWithEvent("sched_getattr", linuxerr.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 316: syscalls.ErrorWithEvent("renameat2", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) 317: syscalls.Supported("seccomp", Seccomp), 318: syscalls.Supported("getrandom", GetRandom), 319: syscalls.Supported("memfd_create", MemfdCreate), 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), 322: syscalls.Supported("execveat", Execveat), - 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) + 323: syscalls.ErrorWithEvent("userfaultfd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) 324: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil), 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), // Syscalls implemented after 325 are "backports" from versions // of Linux after 4.4. - 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), + 326: syscalls.ErrorWithEvent("copy_file_range", linuxerr.ENOSYS, "", nil), 327: syscalls.Supported("preadv2", Preadv2), 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), - 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), - 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), - 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), + 329: syscalls.ErrorWithEvent("pkey_mprotect", linuxerr.ENOSYS, "", nil), + 330: syscalls.ErrorWithEvent("pkey_alloc", linuxerr.ENOSYS, "", nil), + 331: syscalls.ErrorWithEvent("pkey_free", linuxerr.ENOSYS, "", nil), 332: syscalls.Supported("statx", Statx), - 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 333: syscalls.ErrorWithEvent("io_pgetevents", linuxerr.ENOSYS, "", nil), 334: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil), // Linux skips ahead to syscall 424 to sync numbers between arches. - 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), - 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), - 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), - 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), - 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), - 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), - 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), - 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), - 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), - 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), - 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), - 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 424: syscalls.ErrorWithEvent("pidfd_send_signal", linuxerr.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", linuxerr.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", linuxerr.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", linuxerr.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", linuxerr.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", linuxerr.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", linuxerr.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", linuxerr.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", linuxerr.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", linuxerr.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", linuxerr.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", linuxerr.ENOSYS, "", nil), 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{ @@ -414,7 +413,7 @@ var AMD64 = &kernel.SyscallTable{ }, Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { t.Kernel().EmitUnimplementedEvent(t) - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS }, } @@ -472,7 +471,7 @@ var ARM64 = &kernel.SyscallTable{ 39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), 40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), 41: syscalls.Error("pivot_root", linuxerr.EPERM, "", nil), - 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), + 42: syscalls.Error("nfsservctl", linuxerr.ENOSYS, "Removed after Linux 3.1.", nil), 43: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), 44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), 45: syscalls.Supported("truncate", Truncate), @@ -505,7 +504,7 @@ var ARM64 = &kernel.SyscallTable{ 72: syscalls.Supported("pselect", Pselect), 73: syscalls.Supported("ppoll", Ppoll), 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), - 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 75: syscalls.ErrorWithEvent("vmsplice", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) 76: syscalls.Supported("splice", Splice), 77: syscalls.Supported("tee", Tee), 78: syscalls.Supported("readlinkat", Readlinkat), @@ -581,8 +580,8 @@ var ARM64 = &kernel.SyscallTable{ 148: syscalls.Supported("getresuid", Getresuid), 149: syscalls.Supported("setresgid", Setresgid), 150: syscalls.Supported("getresgid", Getresgid), - 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) - 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 151: syscalls.ErrorWithEvent("setfsuid", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 152: syscalls.ErrorWithEvent("setfsgid", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) 153: syscalls.Supported("times", Times), 154: syscalls.Supported("setpgid", Setpgid), 155: syscalls.Supported("getpgid", Getpgid), @@ -610,14 +609,14 @@ var ARM64 = &kernel.SyscallTable{ 177: syscalls.Supported("getegid", Getegid), 178: syscalls.Supported("gettid", Gettid), 179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), - 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 180: syscalls.ErrorWithEvent("mq_open", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 181: syscalls.ErrorWithEvent("mq_unlink", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 182: syscalls.ErrorWithEvent("mq_timedsend", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 183: syscalls.ErrorWithEvent("mq_timedreceive", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 184: syscalls.ErrorWithEvent("mq_notify", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 185: syscalls.ErrorWithEvent("mq_getsetattr", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) 186: syscalls.Supported("msgget", Msgget), - 187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}), + 187: syscalls.Supported("msgctl", Msgctl), 188: syscalls.Supported("msgrcv", Msgrcv), 189: syscalls.Supported("msgsnd", Msgsnd), 190: syscalls.Supported("semget", Semget), @@ -664,7 +663,7 @@ var ARM64 = &kernel.SyscallTable{ 231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), 232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), 233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), - 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), + 234: syscalls.ErrorWithEvent("remap_file_pages", linuxerr.ENOSYS, "Deprecated since Linux 3.16.", nil), 235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), 236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), 237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), @@ -676,60 +675,60 @@ var ARM64 = &kernel.SyscallTable{ 243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), 260: syscalls.Supported("wait4", Wait4), 261: syscalls.Supported("prlimit64", Prlimit64), - 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), - 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 262: syscalls.ErrorWithEvent("fanotify_init", linuxerr.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 263: syscalls.ErrorWithEvent("fanotify_mark", linuxerr.ENOSYS, "Needs CONFIG_FANOTIFY", nil), 264: syscalls.Error("name_to_handle_at", linuxerr.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), 265: syscalls.Error("open_by_handle_at", linuxerr.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), 266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), 267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), 268: syscalls.ErrorWithEvent("setns", linuxerr.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) 269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), - 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), - 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 270: syscalls.ErrorWithEvent("process_vm_readv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 271: syscalls.ErrorWithEvent("process_vm_writev", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/158"}), 272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), 273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), - 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) + 274: syscalls.ErrorWithEvent("sched_setattr", linuxerr.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 275: syscalls.ErrorWithEvent("sched_getattr", linuxerr.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 276: syscalls.ErrorWithEvent("renameat2", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) 277: syscalls.Supported("seccomp", Seccomp), 278: syscalls.Supported("getrandom", GetRandom), 279: syscalls.Supported("memfd_create", MemfdCreate), 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), 281: syscalls.Supported("execveat", Execveat), - 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) + 282: syscalls.ErrorWithEvent("userfaultfd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) 283: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil), 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), // Syscalls after 284 are "backports" from versions of Linux after 4.4. - 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), + 285: syscalls.ErrorWithEvent("copy_file_range", linuxerr.ENOSYS, "", nil), 286: syscalls.Supported("preadv2", Preadv2), 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), - 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), - 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), - 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), + 288: syscalls.ErrorWithEvent("pkey_mprotect", linuxerr.ENOSYS, "", nil), + 289: syscalls.ErrorWithEvent("pkey_alloc", linuxerr.ENOSYS, "", nil), + 290: syscalls.ErrorWithEvent("pkey_free", linuxerr.ENOSYS, "", nil), 291: syscalls.Supported("statx", Statx), - 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 292: syscalls.ErrorWithEvent("io_pgetevents", linuxerr.ENOSYS, "", nil), 293: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil), // Linux skips ahead to syscall 424 to sync numbers between arches. - 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), - 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), - 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), - 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), - 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), - 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), - 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), - 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), - 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), - 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), - 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), - 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 424: syscalls.ErrorWithEvent("pidfd_send_signal", linuxerr.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", linuxerr.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", linuxerr.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", linuxerr.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", linuxerr.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", linuxerr.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", linuxerr.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", linuxerr.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", linuxerr.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", linuxerr.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", linuxerr.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", linuxerr.ENOSYS, "", nil), 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{}, Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { t.Kernel().EmitUnimplementedEvent(t) - return 0, syserror.ENOSYS + return 0, linuxerr.ENOSYS }, } diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go index 9dea78085..373948991 100644 --- a/pkg/sentry/syscalls/linux/sigset.go +++ b/pkg/sentry/syscalls/linux/sigset.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // CopyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and @@ -67,6 +66,6 @@ func copyInSigSetWithSize(t *kernel.Task, addr hostarch.Addr) (hostarch.Addr, ui maskSize := uint(hostarch.ByteOrder.Uint64(in[8:])) return maskAddr, maskSize, nil default: - return 0, 0, syserror.ENOSYS + return 0, 0, linuxerr.ENOSYS } } diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index 4ce3430e2..2f00c3783 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -26,7 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/usermem" ) @@ -138,7 +138,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if count > 0 || linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return uintptr(count), nil, nil } - return 0, nil, syserror.ConvertIntr(err, syserror.EINTR) + return 0, nil, syserr.ConvertIntr(err, linuxerr.EINTR) } } @@ -216,7 +216,7 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) // It is not presently supported (ENOSYS indicates no support on this // architecture). func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } // LINT.IfChange @@ -355,7 +355,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } cbAddr = hostarch.Addr(cbAddrP) default: - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } // Copy in this callback. diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index daa151bb4..6c807124c 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -22,7 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" "gvisor.dev/gvisor/pkg/sentry/syscalls" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/waiter" ) @@ -109,7 +109,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func waitEpoll(t *kernel.Task, fd int32, eventsAddr hostarch.Addr, max int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { r, err := syscalls.WaitEpoll(t, fd, max, timeoutInNanos) if err != nil { - return 0, nil, syserror.ConvertIntr(err, syserror.EINTR) + return 0, nil, syserr.ConvertIntr(err, linuxerr.EINTR) } if len(r) != 0 { diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 3528d325f..e79b92fb6 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -30,7 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/fasync" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" ) // fileOpAt performs an operation on the second last component in the path. @@ -122,7 +122,7 @@ func copyInPath(t *kernel.Task, addr hostarch.Addr, allowEmpty bool) (path strin return "", false, err } if path == "" && !allowEmpty { - return "", false, syserror.ENOENT + return "", false, linuxerr.ENOENT } // If the path ends with a /, then checks must be enforced in various @@ -162,7 +162,7 @@ func openAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint) (fd uin if fs.IsDir(d.Inode.StableAttr) { // Don't allow directories to be opened writable. if fileFlags.Write { - return syserror.EISDIR + return linuxerr.EISDIR } } else { // If O_DIRECTORY is set, but the file is not a directory, then fail. @@ -177,7 +177,7 @@ func openAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint) (fd uin file, err := d.Inode.GetFile(t, d, fileFlags) if err != nil { - return syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } defer file.DecRef(t) @@ -215,7 +215,7 @@ func mknodAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMod return err } if dirPath { - return syserror.ENOENT + return linuxerr.ENOENT } return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error { @@ -308,7 +308,7 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode return 0, err } if dirPath { - return 0, syserror.ENOENT + return 0, linuxerr.ENOENT } fileFlags := linuxToFlags(flags) @@ -416,7 +416,7 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode // Create a new fs.File. newFile, err = found.Inode.GetFile(t, found, fileFlags) if err != nil { - return syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } defer newFile.DecRef(t) case linuxerr.Equals(linuxerr.ENOENT, err): @@ -795,7 +795,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall defer file.DecRef(t) err := file.Flush(t) - return 0, nil, handleIOError(t, false /* partial */, err, syserror.EINTR, "close", file) + return 0, nil, handleIOError(t, false /* partial */, err, linuxerr.EINTR, "close", file) } // Dup implements linux syscall dup(2). @@ -1020,7 +1020,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } else { // Blocking lock, pass in the task to satisfy the lock.Blocker interface. if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.ReadLock, rng, t) { - return 0, nil, syserror.EINTR + return 0, nil, linuxerr.EINTR } } return 0, nil, nil @@ -1036,7 +1036,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } else { // Blocking lock, pass in the task to satisfy the lock.Blocker interface. if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.WriteLock, rng, t) { - return 0, nil, syserror.EINTR + return 0, nil, linuxerr.EINTR } } return 0, nil, nil @@ -1263,7 +1263,7 @@ func symlinkAt(t *kernel.Task, dirFD int32, newAddr hostarch.Addr, oldAddr hosta return err } if dirPath { - return syserror.ENOENT + return linuxerr.ENOENT } // The oldPath is copied in verbatim. This is because the symlink @@ -1273,7 +1273,7 @@ func symlinkAt(t *kernel.Task, dirFD int32, newAddr hostarch.Addr, oldAddr hosta return err } if oldPath == "" { - return syserror.ENOENT + return linuxerr.ENOENT } return fileOpAt(t, dirFD, newPath, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error { @@ -1352,7 +1352,7 @@ func linkAt(t *kernel.Task, oldDirFD int32, oldAddr hostarch.Addr, newDirFD int3 return err } if dirPath { - return syserror.ENOENT + return linuxerr.ENOENT } if allowEmpty && oldPath == "" { @@ -1439,7 +1439,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal allowEmpty := flags&linux.AT_EMPTY_PATH == linux.AT_EMPTY_PATH if allowEmpty && !t.HasCapabilityIn(linux.CAP_DAC_READ_SEARCH, t.UserNamespace().Root()) { - return 0, nil, syserror.ENOENT + return 0, nil, linuxerr.ENOENT } return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty) @@ -1455,7 +1455,7 @@ func readlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, bufAddr hostarc return 0, err } if dirPath { - return 0, syserror.ENOENT + return 0, linuxerr.ENOENT } err = fileOpOn(t, dirFD, path, false /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { @@ -1579,7 +1579,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { if fs.IsDir(d.Inode.StableAttr) { - return syserror.EISDIR + return linuxerr.EISDIR } // In contrast to open(O_TRUNC), truncate(2) is only valid for file // types. @@ -2131,7 +2131,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, linuxerr.ESPIPE } if fs.IsDir(file.Dirent.Inode.StableAttr) { - return 0, nil, syserror.EISDIR + return 0, nil, linuxerr.EISDIR } if !fs.IsRegular(file.Dirent.Inode.StableAttr) { return 0, nil, linuxerr.ENODEV @@ -2189,7 +2189,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } else { // Because we're blocking we will pass the task to satisfy the lock.Blocker interface. if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.WriteLock, rng, t) { - return 0, nil, syserror.EINTR + return 0, nil, linuxerr.EINTR } } case linux.LOCK_SH: @@ -2201,7 +2201,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } else { // Because we're blocking we will pass the task to satisfy the lock.Blocker interface. if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.ReadLock, rng, t) { - return 0, nil, syserror.EINTR + return 0, nil, linuxerr.EINTR } } case linux.LOCK_UN: diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index 717cec04d..bcdd7b633 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -23,7 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" ) // futexWaitRestartBlock encapsulates the state required to restart futex(2) @@ -75,7 +75,7 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo } t.Futex().WaitComplete(w, t) - return 0, syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } // futexWaitDuration performs a FUTEX_WAIT, blocking until the wait is @@ -103,7 +103,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add // The wait was unsuccessful for some reason other than interruption. Simply // forward the error. - if err != syserror.ErrInterrupted { + if err != linuxerr.ErrInterrupted { return 0, err } @@ -111,7 +111,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add // The wait duration was absolute, restart with the original arguments. if forever { - return 0, syserror.ERESTARTSYS + return 0, linuxerr.ERESTARTSYS } // The wait duration was relative, restart with the remaining duration. @@ -122,7 +122,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add val: val, mask: mask, }) - return 0, syserror.ERESTART_RESTARTBLOCK + return 0, linuxerr.ERESTART_RESTARTBLOCK } func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr hostarch.Addr, private bool) error { @@ -150,7 +150,7 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr hostarch. } t.Futex().WaitComplete(w, t) - return syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } func tryLockPI(t *kernel.Task, addr hostarch.Addr, private bool) error { @@ -280,11 +280,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FUTEX_WAIT_REQUEUE_PI, linux.FUTEX_CMP_REQUEUE_PI: t.Kernel().EmitUnimplementedEvent(t) - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS default: // We don't even know about this command. - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } } diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index 917717e31..9f7a5ae8a 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -83,7 +82,7 @@ func getdents(t *kernel.Task, fd int32, addr hostarch.Addr, size int, f func(*di ds := newDirentSerializer(f, w, t.Arch(), size) rerr := dir.Readdir(t, ds) - switch err := handleIOError(t, ds.Written() > 0, rerr, syserror.ERESTARTSYS, "getdents", dir); err { + switch err := handleIOError(t, ds.Written() > 0, rerr, linuxerr.ERESTARTSYS, "getdents", dir); err { case nil: dir.Dirent.InotifyEvent(linux.IN_ACCESS, 0) return uintptr(ds.Written()), nil diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go index bf71a9af3..4a5712a29 100644 --- a/pkg/sentry/syscalls/linux/sys_lseek.go +++ b/pkg/sentry/syscalls/linux/sys_lseek.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // LINT.IfChange @@ -49,7 +48,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } offset, serr := file.Seek(t, sw, offset) - err := handleIOError(t, false /* partialResult */, serr, syserror.ERESTARTSYS, "lseek", file) + err := handleIOError(t, false /* partialResult */, serr, linuxerr.ERESTARTSYS, "lseek", file) if err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index cee621791..7efd17d40 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -24,7 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" ) // Brk implements linux syscall brk(2). @@ -211,7 +211,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca case linux.MADV_REMOVE: // These "suggestions" have application-visible side effects, so we // have to indicate that we don't support them. - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS case linux.MADV_HWPOISON: // Only privileged processes are allowed to poison pages. return 0, nil, linuxerr.EPERM @@ -235,18 +235,18 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // rounded up to the next multiple of the page size." - mincore(2) la, ok := hostarch.Addr(length).RoundUp() if !ok { - return 0, nil, syserror.ENOMEM + return 0, nil, linuxerr.ENOMEM } ar, ok := addr.ToRange(uint64(la)) if !ok { - return 0, nil, syserror.ENOMEM + return 0, nil, linuxerr.ENOMEM } // Pretend that all mapped pages are "resident in core". mapped := t.MemoryManager().VirtualMemorySizeRange(ar) // "ENOMEM: addr to addr + length contained unmapped memory." if mapped != uint64(la) { - return 0, nil, syserror.ENOMEM + return 0, nil, linuxerr.ENOMEM } resident := bytes.Repeat([]byte{1}, int(mapped/hostarch.PageSize)) _, err := t.CopyOutBytes(vec, resident) @@ -277,7 +277,7 @@ func Msync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall }) // MSync calls fsync, the same interrupt conversion rules apply, see // mm/msync.c, fsync POSIX.1-2008. - return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } // Mlock implements linux syscall mlock(2). diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go index 5259ade90..60b989ee7 100644 --- a/pkg/sentry/syscalls/linux/sys_msgqueue.go +++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go @@ -130,12 +130,63 @@ func receive(t *kernel.Task, id ipc.ID, mType int64, maxSize int64, msgCopy, wai func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { id := ipc.ID(args[0].Int()) cmd := args[1].Int() + buf := args[2].Pointer() creds := auth.CredentialsFromContext(t) + r := t.IPCNamespace().MsgqueueRegistry() + switch cmd { + case linux.IPC_INFO: + info := r.IPCInfo(t) + _, err := info.CopyOut(t, buf) + return 0, nil, err + case linux.MSG_INFO: + msgInfo := r.MsgInfo(t) + _, err := msgInfo.CopyOut(t, buf) + return 0, nil, err case linux.IPC_RMID: - return 0, nil, t.IPCNamespace().MsgqueueRegistry().Remove(id, creds) + return 0, nil, r.Remove(id, creds) + } + + // Remaining commands use a queue. + queue, err := r.FindByID(id) + if err != nil { + return 0, nil, err + } + + switch cmd { + case linux.MSG_STAT: + // Technically, we should be treating id as "an index into the kernel's + // internal array that maintains information about all shared memory + // segments on the system". Since we don't track segments in an array, + // we'll just pretend the msqid is the index and do the same thing as + // IPC_STAT. Linux also uses the index as the msqid. + fallthrough + case linux.IPC_STAT: + stat, err := queue.Stat(t) + if err != nil { + return 0, nil, err + } + _, err = stat.CopyOut(t, buf) + return 0, nil, err + + case linux.MSG_STAT_ANY: + stat, err := queue.StatAny(t) + if err != nil { + return 0, nil, err + } + _, err = stat.CopyOut(t, buf) + return 0, nil, err + + case linux.IPC_SET: + var ds linux.MsqidDS + if _, err := ds.CopyIn(t, buf); err != nil { + return 0, nil, linuxerr.EINVAL + } + err := queue.Set(t, &ds) + return 0, nil, err + default: return 0, nil, linuxerr.EINVAL } diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index a80c84fcd..ee4dbbc64 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -25,7 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/waiter" ) @@ -185,7 +185,7 @@ func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration pfd[i].Events |= linux.POLLHUP | linux.POLLERR } remainingTimeout, n, err := pollBlock(t, pfd, timeout) - err = syserror.ConvertIntr(err, syserror.EINTR) + err = syserr.ConvertIntr(err, linuxerr.EINTR) // The poll entries are copied out regardless of whether // any are set or not. This aligns with the Linux behavior. @@ -295,7 +295,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Ad // Do the syscall, then count the number of bits set. if _, _, err = pollBlock(t, pfd, timeout); err != nil { - return 0, syserror.ConvertIntr(err, syserror.EINTR) + return 0, syserr.ConvertIntr(err, linuxerr.EINTR) } // r, w, and e are currently event mask bitsets; unset bits corresponding @@ -411,7 +411,7 @@ func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duratio nfds: nfds, timeout: remainingTimeout, }) - return 0, syserror.ERESTART_RESTARTBLOCK + return 0, linuxerr.ERESTART_RESTARTBLOCK } return n, err } @@ -465,7 +465,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { - err = syserror.ERESTARTNOHAND + err = linuxerr.ERESTARTNOHAND } return n, nil, err } @@ -495,7 +495,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { - err = syserror.ERESTARTNOHAND + err = linuxerr.ERESTARTNOHAND } return n, nil, err } @@ -540,7 +540,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { - err = syserror.ERESTARTNOHAND + err = linuxerr.ERESTARTNOHAND } return n, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index b54a3a11f..18ea23913 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -72,7 +71,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "read", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "read", file) } // Readahead implements readahead(2). @@ -152,7 +151,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pread64", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pread64", file) } // Readv implements linux syscall readv(2). @@ -182,7 +181,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "readv", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "readv", file) } // Preadv implements linux syscall preadv(2). @@ -223,7 +222,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv", file) } // Preadv2 implements linux syscall preadv2(2). @@ -281,17 +280,17 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if offset == -1 { n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv2", file) } n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv2", file) } func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) { n, err := f.Readv(t, dst) - if err != syserror.ErrWouldBlock || f.Flags().NonBlocking { + if err != linuxerr.ErrWouldBlock || f.Flags().NonBlocking { if n > 0 { // Queue notification if we read anything. f.Dirent.InotifyEvent(linux.IN_ACCESS, 0) @@ -304,7 +303,7 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) { var deadline ktime.Time if s, ok := f.FileOperations.(socket.Socket); ok { dl := s.RecvTimeout() - if dl < 0 && err == syserror.ErrWouldBlock { + if dl < 0 && err == linuxerr.ErrWouldBlock { return n, err } if dl > 0 { @@ -326,14 +325,14 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) { // other than "would block". n, err = f.Readv(t, dst) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -351,7 +350,7 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) { func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { n, err := f.Preadv(t, dst, offset) - if err != syserror.ErrWouldBlock || f.Flags().NonBlocking { + if err != linuxerr.ErrWouldBlock || f.Flags().NonBlocking { if n > 0 { // Queue notification if we read anything. f.Dirent.InotifyEvent(linux.IN_ACCESS, 0) @@ -372,7 +371,7 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i // other than "would block". n, err = f.Preadv(t, dst, offset+total) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index a12e1c915..7210333d2 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/syserror" ) // rlimit describes an implementation of 'struct rlimit', which may vary from @@ -44,7 +43,7 @@ func newRlimit(t *kernel.Task) (rlimit, error) { // On 64-bit system, struct rlimit and struct rlimit64 are identical. return &rlimit64{}, nil default: - return nil, syserror.ENOSYS + return nil, linuxerr.ENOSYS } } diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go index 5fe196647..8328a3742 100644 --- a/pkg/sentry/syscalls/linux/sys_rseq.go +++ b/pkg/sentry/syscalls/linux/sys_rseq.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // RSeq implements syscall rseq(2). @@ -33,7 +32,7 @@ func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Event for applications that want rseq on a configuration // that doesn't support them. t.Kernel().EmitUnimplementedEvent(t) - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } switch flags { diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index f61cc466c..5a119b21c 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/ipc" @@ -166,8 +165,7 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, err } - perms := fs.FilePermsFromMode(linux.FileMode(s.SemPerm.Mode & 0777)) - return 0, nil, ipcSet(t, id, auth.UID(s.SemPerm.UID), auth.GID(s.SemPerm.GID), perms) + return 0, nil, ipcSet(t, id, &s) case linux.GETPID: v, err := getPID(t, id, num) @@ -243,24 +241,13 @@ func remove(t *kernel.Task, id ipc.ID) error { return r.Remove(id, creds) } -func ipcSet(t *kernel.Task, id ipc.ID, uid auth.UID, gid auth.GID, perms fs.FilePermissions) error { +func ipcSet(t *kernel.Task, id ipc.ID, ds *linux.SemidDS) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { return linuxerr.EINVAL } - - creds := auth.CredentialsFromContext(t) - kuid := creds.UserNamespace.MapToKUID(uid) - if !kuid.Ok() { - return linuxerr.EINVAL - } - kgid := creds.UserNamespace.MapToKGID(gid) - if !kgid.Ok() { - return linuxerr.EINVAL - } - owner := fs.FileOwner{UID: kuid, GID: kgid} - return set.Change(t, creds, owner, perms) + return set.Set(t, ds) } func ipcStat(t *kernel.Task, id ipc.ID) (*linux.SemidDS, error) { diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 45608f3fa..03871d713 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -25,7 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" ) // "For a process to have permission to send a signal it must @@ -348,7 +348,7 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // Pause implements linux syscall pause(2). func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - return 0, nil, syserror.ConvertIntr(t.Block(nil), syserror.ERESTARTNOHAND) + return 0, nil, syserr.ConvertIntr(t.Block(nil), linuxerr.ERESTARTNOHAND) } // RtSigpending implements linux syscall rt_sigpending(2). @@ -496,7 +496,7 @@ func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. t.SetSavedSignalMask(oldmask) // Perform the wait. - return 0, nil, syserror.ConvertIntr(t.Block(nil), syserror.ERESTARTNOHAND) + return 0, nil, syserr.ConvertIntr(t.Block(nil), linuxerr.ERESTARTNOHAND) } // RestartSyscall implements the linux syscall restart_syscall(2). @@ -512,7 +512,7 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne // function is never null by (re)initializing it with one that translates // the restart into EINTR. We'll emulate that behaviour. t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?") - return 0, nil, syserror.EINTR + return 0, nil, linuxerr.EINTR } // sharedSignalfd is shared between the two calls. diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 06eb8f319..50ddbc142 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -260,7 +259,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Capture address and call syscall implementation. @@ -270,7 +269,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } blocking := !file.Flags().NonBlocking - return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS) + return 0, nil, syserr.ConvertIntr(s.Connect(t, a, blocking).ToError(), linuxerr.ERESTARTSYS) } // accept is the implementation of the accept syscall. It is called by accept @@ -291,7 +290,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, syserror.ENOTSOCK + return 0, linuxerr.ENOTSOCK } // Call the syscall implementation for this socket, then copy the @@ -301,7 +300,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, peerRequested := addrLen != 0 nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS) } if peerRequested { // NOTE(magi): Linux does not give you an error if it can't @@ -350,7 +349,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Capture address and call syscall implementation. @@ -377,7 +376,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } if backlog > maxListenBacklog { @@ -415,7 +414,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Validate how, then call syscall implementation. @@ -446,7 +445,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Read the length. Reject negative values. @@ -527,7 +526,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } if optLen < 0 { @@ -565,7 +564,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Get the socket name and copy it to the caller. @@ -593,7 +592,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Get the socket peer name and copy it to the caller. @@ -626,7 +625,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Reject flags that we don't handle yet. @@ -683,7 +682,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } if file.Flags().NonBlocking { @@ -763,7 +762,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) if err != nil { - return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(err.ToError(), linuxerr.ERESTARTSYS) } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC @@ -785,7 +784,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags } n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS) } defer cms.Release(t) @@ -848,7 +847,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, fla // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, syserror.ENOTSOCK + return 0, linuxerr.ENOTSOCK } if file.Flags().NonBlocking { @@ -874,7 +873,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, fla n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) cm.Release(t) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS) } // Copy the address to the caller. @@ -921,7 +920,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Reject flags that we don't handle yet. @@ -963,7 +962,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Reject flags that we don't handle yet. @@ -1060,7 +1059,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr hostar // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) - err = handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file) + err = handleIOError(t, n != 0, e.ToError(), linuxerr.ERESTARTSYS, "sendmsg", file) // Control messages should be released on error as well as for zero-length // messages, which are discarded by the receiver. if n == 0 || err != nil { @@ -1087,7 +1086,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags // Extract the socket. s, ok := file.FileOperations.(socket.Socket) if !ok { - return 0, syserror.ENOTSOCK + return 0, linuxerr.ENOTSOCK } if file.Flags().NonBlocking { @@ -1122,7 +1121,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) - return uintptr(n), handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file) + return uintptr(n), handleIOError(t, n != 0, e.ToError(), linuxerr.ERESTARTSYS, "sendto", file) } // SendTo implements the linux syscall sendto(2). diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 34d87ac1f..8c8847efa 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -46,9 +45,9 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB for { n, err = fs.Splice(t, outFile, inFile, opts) - if n != 0 || err != syserror.ErrWouldBlock { + if n != 0 || err != linuxerr.ErrWouldBlock { break - } else if err == syserror.ErrWouldBlock && nonBlocking { + } else if err == linuxerr.ErrWouldBlock && nonBlocking { break } @@ -177,7 +176,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // We can only pass a single file to handleIOError, so pick inFile // arbitrarily. This is used only for debugging purposes. - return uintptr(n), nil, handleIOError(t, false, err, syserror.ERESTARTSYS, "sendfile", inFile) + return uintptr(n), nil, handleIOError(t, false, err, linuxerr.ERESTARTSYS, "sendfile", inFile) } // Splice implements splice(2). @@ -287,7 +286,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "splice", inFile) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "splice", inFile) } // Tee imlements tee(2). @@ -340,5 +339,5 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo } // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, false, err, syserror.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, handleIOError(t, false, err, linuxerr.ERESTARTSYS, "tee", inFile) } diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go index 6278bef21..0c22599bf 100644 --- a/pkg/sentry/syscalls/linux/sys_sync.go +++ b/pkg/sentry/syscalls/linux/sys_sync.go @@ -20,7 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" ) // LINT.IfChange @@ -58,7 +58,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall defer file.DecRef(t) err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll) - return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } // Fdatasync implements linux syscall fdatasync(2). @@ -74,7 +74,7 @@ func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys defer file.DecRef(t) err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData) - return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } // SyncFileRange implements linux syscall sync_file_rage(2) @@ -112,7 +112,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel if uflags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 && uflags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 { t.Kernel().EmitUnimplementedEvent(t) - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } // SYNC_FILE_RANGE_WRITE initiates write-out of all dirty pages in the @@ -137,7 +137,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel err = file.Fsync(t, offset, fs.FileMaxOffset, fs.SyncData) } - return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } // LINT.ThenChange(vfs2/sync.go) diff --git a/pkg/sentry/syscalls/linux/sys_syslog.go b/pkg/sentry/syscalls/linux/sys_syslog.go index ba372f9e3..15acb2b8b 100644 --- a/pkg/sentry/syscalls/linux/sys_syslog.go +++ b/pkg/sentry/syscalls/linux/sys_syslog.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -57,6 +56,6 @@ func Syslog(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal case _SYSLOG_ACTION_SIZE_BUFFER: return logBufLen, nil, nil default: - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 981cdd985..d74173c56 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" "gvisor.dev/gvisor/pkg/sentry/loader" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -111,7 +110,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr host } atEmptyPath := flags&linux.AT_EMPTY_PATH != 0 if !atEmptyPath && len(pathname) == 0 { - return 0, nil, syserror.ENOENT + return 0, nil, linuxerr.ENOENT } resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0 @@ -244,7 +243,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error { wopts.Events |= kernel.EventGroupContinue } if options&linux.WNOHANG == 0 { - wopts.BlockInterruptErr = syserror.ERESTARTSYS + wopts.BlockInterruptErr = linuxerr.ERESTARTSYS } if options&linux.WNOTHREAD == 0 { wopts.SiblingChildren = true diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index 674e74f82..4adc8b8a4 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/syserror" ) // The most significant 29 bits hold either a pid or a file descriptor. @@ -214,7 +213,7 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, end ktime.Time, rem host case linuxerr.Equals(linuxerr.ETIMEDOUT, err): // Slept for entire timeout. return nil - case err == syserror.ErrInterrupted: + case err == linuxerr.ErrInterrupted: // Interrupted. remaining := end.Sub(c.Now()) if remaining <= 0 { @@ -235,9 +234,9 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, end ktime.Time, rem host end: end, rem: rem, }) - return syserror.ERESTART_RESTARTBLOCK + return linuxerr.ERESTART_RESTARTBLOCK } - return syserror.ERESTARTNOHAND + return linuxerr.ERESTARTNOHAND default: panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err)) } diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go index 45eef4feb..d39a0a6f5 100644 --- a/pkg/sentry/syscalls/linux/sys_timer.go +++ b/pkg/sentry/syscalls/linux/sys_timer.go @@ -18,9 +18,9 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) const nsecPerSec = int64(time.Second) @@ -29,7 +29,7 @@ const nsecPerSec = int64(time.Second) func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { if t.Arch().Width() != 8 { // Definition of linux.ItimerVal assumes 64-bit architecture. - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } timerID := args[0].Int() @@ -51,7 +51,7 @@ func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { if t.Arch().Width() != 8 { // Definition of linux.ItimerVal assumes 64-bit architecture. - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } timerID := args[0].Int() diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go index 8c6cd7511..bde672d67 100644 --- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go +++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // ArchPrctl implements linux syscall arch_prctl(2). @@ -39,7 +38,7 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } default: - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } case linux.ARCH_SET_FS: fsbase := args[1].Uint64() diff --git a/pkg/sentry/syscalls/linux/sys_tls_arm64.go b/pkg/sentry/syscalls/linux/sys_tls_arm64.go index ff4ac4d6d..dfa684387 100644 --- a/pkg/sentry/syscalls/linux/sys_tls_arm64.go +++ b/pkg/sentry/syscalls/linux/sys_tls_arm64.go @@ -18,12 +18,12 @@ package linux import ( + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // ArchPrctl is not defined for ARM64. func ArchPrctl(*kernel.Task, arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 872168606..4a4ef5046 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -72,7 +71,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "write", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "write", file) } // Pwrite64 implements linux syscall pwrite64(2). @@ -119,7 +118,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwrite64", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pwrite64", file) } // Writev implements linux syscall writev(2). @@ -149,7 +148,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "writev", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "writev", file) } // Pwritev implements linux syscall pwritev(2). @@ -190,7 +189,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev", file) } // Pwritev2 implements linux syscall pwritev2(2). @@ -251,17 +250,17 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if offset == -1 { n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev2", file) } n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev2", file) } func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) { n, err := f.Writev(t, src) - if err != syserror.ErrWouldBlock || f.Flags().NonBlocking { + if err != linuxerr.ErrWouldBlock || f.Flags().NonBlocking { if n > 0 { // Queue notification if we wrote anything. f.Dirent.InotifyEvent(linux.IN_MODIFY, 0) @@ -274,7 +273,7 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) { var deadline ktime.Time if s, ok := f.FileOperations.(socket.Socket); ok { dl := s.SendTimeout() - if dl < 0 && err == syserror.ErrWouldBlock { + if dl < 0 && err == linuxerr.ErrWouldBlock { return n, err } if dl > 0 { @@ -296,14 +295,14 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) { // anything other than "would block". n, err = f.Writev(t, src) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -321,7 +320,7 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) { func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (int64, error) { n, err := f.Pwritev(t, src, offset) - if err != syserror.ErrWouldBlock || f.Flags().NonBlocking { + if err != linuxerr.ErrWouldBlock || f.Flags().NonBlocking { if n > 0 { // Queue notification if we wrote anything. f.Dirent.InotifyEvent(linux.IN_MODIFY, 0) @@ -342,7 +341,7 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) ( // anything other than "would block". n, err = f.Pwritev(t, src, offset+total) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go index b327e27d6..d90652a3f 100644 --- a/pkg/sentry/syscalls/linux/timespec.go +++ b/pkg/sentry/syscalls/linux/timespec.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // copyTimespecIn copies a Timespec from the untrusted app range to the kernel. @@ -38,7 +37,7 @@ func copyTimespecIn(t *kernel.Task, addr hostarch.Addr) (linux.Timespec, error) ts.Nsec = int64(hostarch.ByteOrder.Uint64(in[8:])) return ts, nil default: - return linux.Timespec{}, syserror.ENOSYS + return linux.Timespec{}, linuxerr.ENOSYS } } @@ -52,7 +51,7 @@ func copyTimespecOut(t *kernel.Task, addr hostarch.Addr, ts *linux.Timespec) err _, err := t.CopyOutBytes(addr, out) return err default: - return syserror.ENOSYS + return linuxerr.ENOSYS } } @@ -70,7 +69,7 @@ func copyTimevalIn(t *kernel.Task, addr hostarch.Addr) (linux.Timeval, error) { tv.Usec = int64(hostarch.ByteOrder.Uint64(in[8:])) return tv, nil default: - return linux.Timeval{}, syserror.ENOSYS + return linux.Timeval{}, linuxerr.ENOSYS } } @@ -84,7 +83,7 @@ func copyTimevalOut(t *kernel.Task, addr hostarch.Addr, tv *linux.Timeval) error _, err := t.CopyOutBytes(addr, out) return err default: - return syserror.ENOSYS + return linuxerr.ENOSYS } } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index a73f096ff..1e3bd2a50 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -73,7 +73,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index a8fa86cdc..0b57c0f7c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -56,7 +55,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } cbAddr = hostarch.Addr(cbAddrP) default: - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } // Copy in this callback. diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go index 38818c175..fcf2e25de 100644 --- a/pkg/sentry/syscalls/linux/vfs2/execve.go +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/loader" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // Execve implements linux syscall execve(2). @@ -83,7 +82,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr host // do_open_execat(fd=AT_FDCWD)), and the loader package is currently // incapable of handling this correctly. if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { - return 0, nil, syserror.ENOENT + return 0, nil, linuxerr.ENOENT } dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd) if dirfile == nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 2cfb12cad..2198aa065 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // Close implements Linux syscall close(2). @@ -42,7 +41,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall defer file.DecRef(t) err := file.OnClose(t) - return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file) + return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, linuxerr.EINTR, "close", file) } // Dup implements Linux syscall dup(2). diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index 534355237..f19f0fd41 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // Link implements Linux syscall link(2). @@ -46,7 +45,7 @@ func linkat(t *kernel.Task, olddirfd int32, oldpathAddr hostarch.Addr, newdirfd return linuxerr.EINVAL } if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) { - return syserror.ENOENT + return linuxerr.ENOENT } oldpath, err := copyInPath(t, oldpathAddr) @@ -320,7 +319,7 @@ func symlinkat(t *kernel.Task, targetAddr hostarch.Addr, newdirfd int32, linkpat return err } if len(target) == 0 { - return syserror.ENOENT + return linuxerr.ENOENT } linkpath, err := copyInPath(t, linkpathAddr) if err != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go index 2bb783a85..38796d4db 100644 --- a/pkg/sentry/syscalls/linux/vfs2/path.go +++ b/pkg/sentry/syscalls/linux/vfs2/path.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) func copyInPath(t *kernel.Task, addr hostarch.Addr) (fspath.Path, error) { @@ -44,7 +43,7 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA if !path.Absolute { if !path.HasComponents() && !bool(shouldAllowEmptyPath) { root.DecRef(t) - return taskPathOperation{}, syserror.ENOENT + return taskPathOperation{}, linuxerr.ENOENT } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index 042aa4c97..204051cd0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -20,15 +20,14 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/waiter" - - "gvisor.dev/gvisor/pkg/hostarch" ) // fileCap is the maximum allowable files for poll & select. This has no @@ -189,7 +188,7 @@ func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration pfd[i].Events |= linux.POLLHUP | linux.POLLERR } remainingTimeout, n, err := pollBlock(t, pfd, timeout) - err = syserror.ConvertIntr(err, syserror.EINTR) + err = syserr.ConvertIntr(err, linuxerr.EINTR) // The poll entries are copied out regardless of whether // any are set or not. This aligns with the Linux behavior. @@ -299,7 +298,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Ad // Do the syscall, then count the number of bits set. if _, _, err = pollBlock(t, pfd, timeout); err != nil { - return 0, syserror.ConvertIntr(err, syserror.EINTR) + return 0, syserr.ConvertIntr(err, linuxerr.EINTR) } // r, w, and e are currently event mask bitsets; unset bits corresponding @@ -417,7 +416,7 @@ func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duratio nfds: nfds, timeout: remainingTimeout, }) - return 0, syserror.ERESTART_RESTARTBLOCK + return 0, linuxerr.ERESTART_RESTARTBLOCK } return n, err } @@ -464,7 +463,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { - err = syserror.ERESTARTNOHAND + err = linuxerr.ERESTARTNOHAND } return n, nil, err } @@ -494,7 +493,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { - err = syserror.ERESTARTNOHAND + err = linuxerr.ERESTARTNOHAND } return n, nil, err } @@ -541,7 +540,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { - err = syserror.ERESTARTNOHAND + err = linuxerr.ERESTARTNOHAND } return n, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index fe8aa06da..4e7dc5080 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -63,7 +62,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC n, err := read(t, file, dst, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "read", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "read", file) } // Readv implements Linux syscall readv(2). @@ -88,12 +87,12 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := read(t, file, dst, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "readv", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "readv", file) } func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { n, err := file.Read(t, dst, opts) - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { return n, err } @@ -115,14 +114,14 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt // "would block". n, err = file.Read(t, dst, opts) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -166,7 +165,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pread64", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "pread64", file) } // Preadv implements Linux syscall preadv(2). @@ -197,7 +196,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "preadv", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv", file) } // Preadv2 implements Linux syscall preadv2(2). @@ -243,12 +242,12 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err = pread(t, file, dst, offset, opts) } t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv2", file) } func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { n, err := file.PRead(t, dst, offset, opts) - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { return n, err } @@ -270,14 +269,14 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of // "would block". n, err = file.PRead(t, dst, offset+total, opts) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -314,7 +313,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := write(t, file, src, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "write", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "write", file) } // Writev implements Linux syscall writev(2). @@ -339,12 +338,12 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := write(t, file, src, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "writev", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "writev", file) } func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { n, err := file.Write(t, src, opts) - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { return n, err } @@ -366,14 +365,14 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op // "would block". n, err = file.Write(t, src, opts) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -416,7 +415,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwrite64", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "pwrite64", file) } // Pwritev implements Linux syscall pwritev(2). @@ -447,7 +446,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwritev", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev", file) } // Pwritev2 implements Linux syscall pwritev2(2). @@ -493,12 +492,12 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err = pwrite(t, file, src, offset, opts) } t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev2", file) } func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { n, err := file.PWrite(t, src, offset, opts) - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { return n, err } @@ -520,14 +519,14 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o // "would block". n, err = file.PWrite(t, src, offset+total, opts) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index b5a3b92c5..e608572b4 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX @@ -432,7 +431,7 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa start := root if !path.Absolute { if !path.HasComponents() && !bool(shouldAllowEmptyPath) { - return syserror.ENOENT + return linuxerr.ENOENT } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() @@ -465,7 +464,7 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa } func handleSetSizeError(t *kernel.Task, err error) error { - if err == syserror.ErrExceedsFileSizeLimit { + if err == linuxerr.ErrExceedsFileSizeLimit { // Convert error to EFBIG and send a SIGXFSZ per setrlimit(2). t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t)) return linuxerr.EFBIG diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 0c2e0720b..48be5a88d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -30,10 +31,7 @@ import ( slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - - "gvisor.dev/gvisor/pkg/hostarch" ) // maxAddrLen is the maximum socket address length we're willing to accept. @@ -264,7 +262,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Capture address and call syscall implementation. @@ -274,7 +272,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0 - return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS) + return 0, nil, syserr.ConvertIntr(s.Connect(t, a, blocking).ToError(), linuxerr.ERESTARTSYS) } // accept is the implementation of the accept syscall. It is called by accept @@ -295,7 +293,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, syserror.ENOTSOCK + return 0, linuxerr.ENOTSOCK } // Call the syscall implementation for this socket, then copy the @@ -305,7 +303,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, peerRequested := addrLen != 0 nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS) } if peerRequested { // NOTE(magi): Linux does not give you an error if it can't @@ -354,7 +352,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Capture address and call syscall implementation. @@ -381,7 +379,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } if backlog > maxListenBacklog { @@ -419,7 +417,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Validate how, then call syscall implementation. @@ -450,7 +448,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Read the length. Reject negative values. @@ -531,7 +529,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } if optLen < 0 { @@ -569,7 +567,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Get the socket name and copy it to the caller. @@ -597,7 +595,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Get the socket peer name and copy it to the caller. @@ -630,7 +628,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Reject flags that we don't handle yet. @@ -687,7 +685,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { @@ -767,7 +765,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr hostarch.Addr, fl if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) if err != nil { - return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(err.ToError(), linuxerr.ERESTARTSYS) } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC @@ -789,7 +787,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr hostarch.Addr, fl } n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS) } defer cms.Release(t) @@ -852,7 +850,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, fla // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, syserror.ENOTSOCK + return 0, linuxerr.ENOTSOCK } if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { @@ -878,7 +876,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, fla n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) cm.Release(t) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) + return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS) } // Copy the address to the caller. @@ -925,7 +923,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Reject flags that we don't handle yet. @@ -967,7 +965,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, nil, syserror.ENOTSOCK + return 0, nil, linuxerr.ENOTSOCK } // Reject flags that we don't handle yet. @@ -1064,7 +1062,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) - err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file) + err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), linuxerr.ERESTARTSYS, "sendmsg", file) // Control messages should be released on error as well as for zero-length // messages, which are discarded by the receiver. if n == 0 || err != nil { @@ -1091,7 +1089,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) if !ok { - return 0, syserror.ENOTSOCK + return 0, linuxerr.ENOTSOCK } if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { @@ -1126,7 +1124,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) - return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file) + return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), linuxerr.ERESTARTSYS, "sendto", file) } // SendTo implements the linux syscall sendto(2). diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index d8009123f..0205f09e0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -151,7 +150,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal panic("at least one end of splice must be a pipe") } - if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + if n != 0 || err != linuxerr.ErrWouldBlock || nonBlock { break } if err = dw.waitForBoth(t); err != nil { @@ -173,7 +172,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // We can only pass a single file to handleIOError, so pick inFile arbitrarily. // This is used only for debugging purposes. - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "splice", outFile) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "splice", outFile) } // Tee implements Linux syscall tee(2). @@ -241,7 +240,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo defer dw.destroy() for { n, err = pipe.Tee(t, outPipeFD, inPipeFD, count) - if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + if n != 0 || err != linuxerr.ErrWouldBlock || nonBlock { break } if err = dw.waitForBoth(t); err != nil { @@ -251,7 +250,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if n != 0 { // If a partial write is completed, the error is dropped. Log it here. - if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { + if err != nil && err != io.EOF && err != linuxerr.ErrWouldBlock { log.Debugf("tee completed a partial write with error: %v", err) err = nil } @@ -259,7 +258,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // We can only pass a single file to handleIOError, so pick inFile arbitrarily. // This is used only for debugging purposes. - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "tee", inFile) } // Sendfile implements linux system call sendfile(2). @@ -360,10 +359,10 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc break } if err == nil && t.Interrupted() { - err = syserror.ErrInterrupted + err = linuxerr.ErrInterrupted break } - if err == syserror.ErrWouldBlock && !nonBlock { + if err == linuxerr.ErrWouldBlock && !nonBlock { err = dw.waitForBoth(t) } if err != nil { @@ -389,7 +388,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc var writeN int64 writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{}) wbuf = wbuf[writeN:] - if err == syserror.ErrWouldBlock && !nonBlock { + if err == linuxerr.ErrWouldBlock && !nonBlock { err = dw.waitForOut(t) } if err != nil { @@ -420,10 +419,10 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc break } if err == nil && t.Interrupted() { - err = syserror.ErrInterrupted + err = linuxerr.ErrInterrupted break } - if err == syserror.ErrWouldBlock && !nonBlock { + if err == linuxerr.ErrWouldBlock && !nonBlock { err = dw.waitForBoth(t) } if err != nil { @@ -441,7 +440,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } if total != 0 { - if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { + if err != nil && err != io.EOF && err != linuxerr.ErrWouldBlock { // If a partial write is completed, the error is dropped. Log it here. log.Debugf("sendfile completed a partial write with error: %v", err) err = nil @@ -450,7 +449,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // We can only pass a single file to handleIOError, so pick inFile arbitrarily. // This is used only for debugging purposes. - return uintptr(total), nil, slinux.HandleIOErrorVFS2(t, total != 0, err, syserror.ERESTARTSYS, "sendfile", inFile) + return uintptr(total), nil, slinux.HandleIOErrorVFS2(t, total != 0, err, linuxerr.ERESTARTSYS, "sendfile", inFile) } // dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go index ba1d30823..adaf8db3f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/stat.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // Stat implements Linux syscall stat(2). @@ -70,7 +69,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr hostarch.Addr, flag start := root if !path.Absolute { if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { - return syserror.ENOENT + return linuxerr.ENOENT } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() @@ -182,7 +181,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall start := root if !path.Absolute { if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { - return 0, nil, syserror.ENOENT + return 0, nil, linuxerr.ENOENT } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index d0ffc7c32..cfc693422 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -19,7 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/syserr" ) // Sync implements Linux syscall sync(2). @@ -108,12 +108,12 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel if flags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 && flags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 { t.Kernel().EmitUnimplementedEvent(t) - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS } if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 { if err := file.Sync(t); err != nil { - return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) + return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS) } } return 0, nil, nil diff --git a/pkg/sentry/syscalls/syscalls.go b/pkg/sentry/syscalls/syscalls.go index 511fb8b28..cfcc21271 100644 --- a/pkg/sentry/syscalls/syscalls.go +++ b/pkg/sentry/syscalls/syscalls.go @@ -31,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // Supported returns a syscall that is fully supported. @@ -103,10 +102,10 @@ func CapError(name string, c linux.Capability, note string, urls []string) kerne return 0, nil, linuxerr.EPERM } t.Kernel().EmitUnimplementedEvent(t) - return 0, nil, syserror.ENOSYS + return 0, nil, linuxerr.ENOSYS }, SupportLevel: kernel.SupportUnimplemented, - Note: fmt.Sprintf("%sReturns %q if the process does not have %s; %q otherwise.", note, linuxerr.EPERM, c.String(), syserror.ENOSYS), + Note: fmt.Sprintf("%sReturns %q if the process does not have %s; %q otherwise.", note, linuxerr.EPERM, c.String(), linuxerr.ENOSYS), URLs: urls, } } diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 36d999c47..c21971322 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -39,7 +39,6 @@ go_library( "//pkg/log", "//pkg/metric", "//pkg/sync", - "//pkg/syserror", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go index 3560e66ae..9b8c9a480 100644 --- a/pkg/sentry/time/sampler_arm64.go +++ b/pkg/sentry/time/sampler_arm64.go @@ -30,9 +30,9 @@ func getDefaultArchOverheadCycles() TSCValue { // frqRatio. defaultOverheadCycles of ARM equals to that on // x86 devided by frqRatio cntfrq := getCNTFRQ() - frqRatio := 1000000000 / cntfrq + frqRatio := 1000000000 / float64(cntfrq) overheadCycles := (1 * 1000) / frqRatio - return overheadCycles + return TSCValue(overheadCycles) } // defaultOverheadTSC is the default estimated syscall overhead in TSC cycles. diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index a2032162d..914574543 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -116,7 +116,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/uniqueid", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", @@ -137,7 +136,6 @@ go_test( "//pkg/errors/linuxerr", "//pkg/sentry/contexttest", "//pkg/sync", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 5aad31b78..82ee2c521 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -1,9 +1,5 @@ # The gVisor Virtual Filesystem -THIS PACKAGE IS CURRENTLY EXPERIMENTAL AND NOT READY OR ENABLED FOR PRODUCTION -USE. For the filesystem implementation currently used by gVisor, see the `fs` -package. - ## Implementation Notes ### Reference Counting diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 23ccc6b66..fefd0fc9c 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -263,7 +262,7 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event num: num, }] if !ok { - return syserror.ENOENT + return linuxerr.ENOENT } // Update epi for the next call to ep.ReadEvents(). @@ -299,7 +298,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error num: num, }] if !ok { - return syserror.ENOENT + return linuxerr.ENOENT } // Unregister from the file so that epi will no longer be readied. diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index a875fdeca..452f5f1f9 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -17,6 +17,7 @@ package vfs import ( "bytes" "io" + "math" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -25,7 +26,6 @@ import ( fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -56,7 +56,7 @@ func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error { // StatFS implements FileDescriptionImpl.StatFS analogously to // super_operations::statfs == NULL in Linux. func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, error) { - return linux.Statfs{}, syserror.ENOSYS + return linux.Statfs{}, linuxerr.ENOSYS } // Allocate implements FileDescriptionImpl.Allocate analogously to @@ -175,27 +175,27 @@ type DirectoryFileDescriptionDefaultImpl struct{} // Allocate implements DirectoryFileDescriptionDefaultImpl.Allocate. func (DirectoryFileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { - return syserror.EISDIR + return linuxerr.EISDIR } // PRead implements FileDescriptionImpl.PRead. func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // Read implements FileDescriptionImpl.Read. func (DirectoryFileDescriptionDefaultImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // PWrite implements FileDescriptionImpl.PWrite. func (DirectoryFileDescriptionDefaultImpl) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // Write implements FileDescriptionImpl.Write. func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { - return 0, syserror.EISDIR + return 0, linuxerr.EISDIR } // DentryMetadataFileDescriptionImpl may be embedded by implementations of @@ -368,7 +368,7 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src writable, ok := fd.data.(WritableDynamicBytesSource) if !ok { - return 0, syserror.EIO + return 0, linuxerr.EIO } n, err := writable.Write(ctx, src, offset) if err != nil { @@ -400,6 +400,9 @@ func (fd *DynamicBytesFileDescriptionImpl) Write(ctx context.Context, src userme // GenericConfigureMMap may be used by most implementations of // FileDescriptionImpl.ConfigureMMap. func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error { + if opts.Offset+opts.Length > math.MaxInt64 { + return linuxerr.EOVERFLOW + } opts.Mappable = m opts.MappingIdentity = fd fd.IncRef() diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 3423dede1..e34a8c11b 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -157,10 +156,10 @@ func TestGenCountFD(t *testing.T) { // Write and PWrite fails. if _, err := fd.Write(ctx, ioseq, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) { - t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) + t.Errorf("Write: got err %v, wanted %v", err, linuxerr.EIO) } if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) { - t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) + t.Errorf("Write: got err %v, wanted %v", err, linuxerr.EIO) } } diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 088beb8e2..17d94b341 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -209,7 +208,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt if i.events.Empty() { // Nothing to read yet, tell caller to block. - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } var writeLen int64 diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index cbe4d8c2d..1853cdca0 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -17,8 +17,8 @@ package vfs import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/syserror" ) // FileLocks supports POSIX and BSD style locks, which correspond to fcntl(2) @@ -47,9 +47,9 @@ func (fl *FileLocks) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerID i // Return an appropriate error for the unsuccessful lock attempt, depending on // whether this is a blocking or non-blocking operation. if block == nil { - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } - return syserror.ERESTARTSYS + return linuxerr.ERESTARTSYS } // UnlockBSD releases a BSD-style lock on the entire file. @@ -69,9 +69,9 @@ func (fl *FileLocks) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPI // Return an appropriate error for the unsuccessful lock attempt, depending on // whether this is a blocking or non-blocking operation. if block == nil { - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } - return syserror.ERESTARTSYS + return linuxerr.ERESTARTSYS } // UnlockPOSIX releases a POSIX-style lock on a file region. diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 4d6b59a26..05a416775 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" ) // A Mount is a replacement of a Dentry (Mount.key.point) from one Filesystem @@ -225,7 +224,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr vdDentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef(ctx) - return syserror.ENOENT + return linuxerr.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and // vfs.mountMu.Lock(). diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index e4da15009..7cc68a157 100644 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go @@ -16,9 +16,9 @@ package vfs import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) var fspathBuilderPool = sync.Pool{ @@ -137,7 +137,7 @@ loop: // Linux's sys_getcwd(). func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) { if vd.dentry.IsDead() { - return "", syserror.ENOENT + return "", linuxerr.ENOENT } b := getFSPathBuilder() diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index 4744514bd..953d31876 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/syserror" ) // AccessTypes is a bitmask of Unix file permissions. @@ -195,7 +194,7 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOpt return err } if limit < int64(stat.Size) { - return syserror.ErrExceedsFileSizeLimit + return linuxerr.ErrExceedsFileSizeLimit } } if stat.Mask&linux.STATX_MODE != 0 { @@ -282,7 +281,7 @@ func CheckLimit(ctx context.Context, offset, size int64) (int64, error) { return size, nil } if offset >= int64(fileSizeLimit) { - return 0, syserror.ErrExceedsFileSizeLimit + return 0, linuxerr.ErrExceedsFileSizeLimit } remaining := int64(fileSizeLimit) - offset if remaining < size { diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 6f58f33ce..40aff2927 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // ResolvingPath represents the state of an in-progress path resolution, shared @@ -224,6 +223,12 @@ func (rp *ResolvingPath) Final() bool { return rp.curPart == 0 && !rp.pit.NextOk() } +// Pit returns a copy of rp's current path iterator. Modifying the iterator +// does not change rp. +func (rp *ResolvingPath) Pit() fspath.Iterator { + return rp.pit +} + // Component returns the current path component in the stream represented by // rp. // @@ -331,7 +336,7 @@ func (rp *ResolvingPath) HandleSymlink(target string) error { return linuxerr.ELOOP } if len(target) == 0 { - return syserror.ENOENT + return linuxerr.ENOENT } rp.symlinks++ targetPath := fspath.Parse(target) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index eb3c60610..1b2a668c0 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -48,7 +48,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts. @@ -281,7 +280,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential if newpop.Path.Absolute { return linuxerr.EEXIST } - return syserror.ENOENT + return linuxerr.ENOENT } if newpop.FollowFinalSymlink { oldVD.DecRef(ctx) @@ -318,7 +317,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia if pop.Path.Absolute { return linuxerr.EEXIST } - return syserror.ENOENT + return linuxerr.ENOENT } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.MkdirAt: file creation paths can't follow final symlink") @@ -348,7 +347,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia } // MknodAt creates a file of the given mode at the given path. It returns an -// error from the syserror package. +// error from the linuxerr package. func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { if !pop.Path.Begin.Ok() { // pop.Path should not be empty in operations that create/delete files. @@ -356,7 +355,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia if pop.Path.Absolute { return linuxerr.EEXIST } - return syserror.ENOENT + return linuxerr.ENOENT } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.MknodAt: file creation paths can't follow final symlink") @@ -494,7 +493,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti if oldpop.Path.Absolute { return linuxerr.EBUSY } - return syserror.ENOENT + return linuxerr.ENOENT } if oldpop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.RenameAt: source path can't follow final symlink") @@ -515,7 +514,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti if newpop.Path.Absolute { return linuxerr.EBUSY } - return syserror.ENOENT + return linuxerr.ENOENT } if newpop.FollowFinalSymlink { oldParentVD.DecRef(ctx) @@ -556,7 +555,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia if pop.Path.Absolute { return linuxerr.EBUSY } - return syserror.ENOENT + return linuxerr.ENOENT } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.RmdirAt: file deletion paths can't follow final symlink") @@ -639,7 +638,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent if pop.Path.Absolute { return linuxerr.EEXIST } - return syserror.ENOENT + return linuxerr.ENOENT } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.SymlinkAt: file creation paths can't follow final symlink") @@ -673,7 +672,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti if pop.Path.Absolute { return linuxerr.EBUSY } - return syserror.ENOENT + return linuxerr.ENOENT } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.UnlinkAt: file deletion paths can't follow final symlink") diff --git a/pkg/shim/runtimeoptions/runtimeoptions.go b/pkg/shim/runtimeoptions/runtimeoptions.go index 072dd87f0..e76d73ea7 100644 --- a/pkg/shim/runtimeoptions/runtimeoptions.go +++ b/pkg/shim/runtimeoptions/runtimeoptions.go @@ -15,3 +15,10 @@ // Package runtimeoptions contains the runtimeoptions proto. package runtimeoptions + +import proto "github.com/gogo/protobuf/proto" + +func init() { + // TODO(gvisor.dev/issue/6449): Upgrade runtimeoptions.proto after upgrading to containerd 1.5 + proto.RegisterType((*Options)(nil), "runtimeoptions.v1.Options") +} diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sighandling/BUILD index 1790d57c9..72f10f982 100644 --- a/pkg/sentry/sighandling/BUILD +++ b/pkg/sighandling/BUILD @@ -8,7 +8,7 @@ go_library( "sighandling.go", "sighandling_unsafe.go", ], - visibility = ["//pkg/sentry:internal"], + visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sighandling/sighandling.go index bdaf8af29..bdaf8af29 100644 --- a/pkg/sentry/sighandling/sighandling.go +++ b/pkg/sighandling/sighandling.go diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sighandling/sighandling_unsafe.go index 3fe5c6770..7deeda042 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sighandling/sighandling_unsafe.go @@ -15,6 +15,7 @@ package sighandling import ( + "fmt" "unsafe" "golang.org/x/sys/unix" @@ -37,3 +38,36 @@ func IgnoreChildStop() error { return nil } + +// ReplaceSignalHandler replaces the existing signal handler for the provided +// signal with the function pointer at `handler`. This bypasses the Go runtime +// signal handlers, and should only be used for low-level signal handlers where +// use of signal.Notify is not appropriate. +// +// It stores the value of the previously set handler in previous. +func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { + var sa linux.SigAction + const maskLen = 8 + + // Get the existing signal handler information, and save the current + // handler. Once we replace it, we will use this pointer to fall back to + // it when we receive other signals. + if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { + return e + } + + // Fail if there isn't a previous handler. + if sa.Handler == 0 { + return fmt.Errorf("previous handler for signal %x isn't set", sig) + } + + *previous = uintptr(sa.Handler) + + // Install our own handler. + sa.Handler = uint64(handler) + if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { + return e + } + + return nil +} diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go index 82b6df18c..7b9c2a4db 100644 --- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go +++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go @@ -37,6 +37,8 @@ func (p *AtomicPtr) loadPtr(v *Value) { // Load returns the value set by the most recent Store. It returns nil if there // has been no previous call to Store. +// +//go:nosplit func (p *AtomicPtr) Load() *Value { return (*Value)(atomic.LoadPointer(&p.ptr)) } diff --git a/pkg/sync/atomicptrmap/BUILD b/pkg/sync/atomicptrmap/BUILD index b0e218c79..1c0085d69 100644 --- a/pkg/sync/atomicptrmap/BUILD +++ b/pkg/sync/atomicptrmap/BUILD @@ -19,6 +19,7 @@ go_template( "Key", "Value", ], + visibility = ["//:sandbox"], deps = [ "//pkg/gohacks", "//pkg/sync", diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD index ceee494fc..d8c4c9613 100644 --- a/pkg/syserr/BUILD +++ b/pkg/syserr/BUILD @@ -14,7 +14,7 @@ go_library( "//pkg/abi/linux/errno", "//pkg/errors", "//pkg/errors/linuxerr", - "//pkg/syserror", + "//pkg/safecopy", "//pkg/tcpip", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go index 558240008..b679f3046 100644 --- a/pkg/syserr/syserr.go +++ b/pkg/syserr/syserr.go @@ -24,7 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/errors" "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/safecopy" ) // Error represents an internal error. @@ -52,12 +52,12 @@ func New(message string, linuxTranslation errno.Errno) *Error { } e := error(unix.Errno(err.errno)) - // syserror.ErrWouldBlock gets translated to linuxerr.EWOULDBLOCK and + // linuxerr.ErrWouldBlock gets translated to linuxerr.EWOULDBLOCK and // enables proper blocking semantics. This should temporary address the // class of blocking bugs that keep popping up with the current state of // the error space. if err.errno == linuxerr.EWOULDBLOCK.Errno() { - e = syserror.ErrWouldBlock + e = linuxerr.ErrWouldBlock } linuxBackwardsTranslations[err.errno] = linuxBackwardsTranslation{err: e, ok: true} @@ -279,16 +279,25 @@ func FromError(err error) *Error { if err == nil { return nil } - if errno, ok := err.(unix.Errno); ok { - return FromHost(errno) - } - if linuxErr, ok := err.(*errors.Error); ok { - return FromHost(unix.Errno(linuxErr.Errno())) + switch e := err.(type) { + case unix.Errno: + return FromHost(e) + case *errors.Error: + return FromHost(unix.Errno(e.Errno())) + case safecopy.SegvError, safecopy.BusError, safecopy.AlignmentError: + return FromHost(unix.EFAULT) } - if errno, ok := syserror.TranslateError(err); ok { - return FromHost(errno) + msg := fmt.Sprintf("err: %s type: %T", err.Error(), err) + panic(msg) +} + +// ConvertIntr converts the provided error code (err) to another one (intr) if +// the first error corresponds to an interrupted operation. +func ConvertIntr(err, intr error) error { + if err == linuxerr.ErrInterrupted { + return intr } - panic("unknown error: " + err.Error()) + return err } diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go deleted file mode 100644 index b24edb364..000000000 --- a/pkg/syserror/syserror.go +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package syserror contains syscall error codes exported as error interface -// instead of Errno. This allows for fast comparison and returns when the -// comparand or return value is of type error because there is no need to -// convert from Errno to an interface, i.e., runtime.convT2I isn't called. -package syserror - -import ( - "errors" - - "golang.org/x/sys/unix" -) - -// The following variables have the same meaning as their syscall equivalent. -var ( - EIDRM = error(unix.EIDRM) - EINTR = error(unix.EINTR) - EIO = error(unix.EIO) - EISDIR = error(unix.EISDIR) - ENOENT = error(unix.ENOENT) - ENOEXEC = error(unix.ENOEXEC) - ENOMEM = error(unix.ENOMEM) - ENOTSOCK = error(unix.ENOTSOCK) - ENOSPC = error(unix.ENOSPC) - ENOSYS = error(unix.ENOSYS) -) - -var ( - // ErrWouldBlock is an internal error used to indicate that an operation - // cannot be satisfied immediately, and should be retried at a later - // time, possibly when the caller has received a notification that the - // operation may be able to complete. It is used by implementations of - // the kio.File interface. - ErrWouldBlock = errors.New("request would block") - - // ErrInterrupted is returned if a request is interrupted before it can - // complete. - ErrInterrupted = errors.New("request was interrupted") - - // ErrExceedsFileSizeLimit is returned if a request would exceed the - // file's size limit. - ErrExceedsFileSizeLimit = errors.New("exceeds file size limit") -) - -// errorMap is the map used to convert generic errors into errnos. -var errorMap = map[error]unix.Errno{} - -// errorUnwrappers is an array of unwrap functions to extract typed errors. -var errorUnwrappers = []func(error) (unix.Errno, bool){} - -// AddErrorTranslation allows modules to populate the error map by adding their -// own translations during initialization. Returns if the error translation is -// accepted or not. A pre-existing translation will not be overwritten by the -// new translation. -func AddErrorTranslation(from error, to unix.Errno) bool { - if _, ok := errorMap[from]; ok { - return false - } - - errorMap[from] = to - return true -} - -// AddErrorUnwrapper registers an unwrap method that can extract a concrete error -// from a typed, but not initialized, error. -func AddErrorUnwrapper(unwrap func(e error) (unix.Errno, bool)) { - errorUnwrappers = append(errorUnwrappers, unwrap) -} - -// TranslateError translates errors to errnos, it will return false if -// the error was not registered. -func TranslateError(from error) (unix.Errno, bool) { - if err, ok := errorMap[from]; ok { - return err, true - } - // Try to unwrap the error if we couldn't match an error - // exactly. This might mean that a package has its own - // error type. - for _, unwrap := range errorUnwrappers { - if err, ok := unwrap(from); ok { - return err, true - } - } - return 0, false -} - -// ConvertIntr converts the provided error code (err) to another one (intr) if -// the first error corresponds to an interrupted operation. -func ConvertIntr(err, intr error) error { - if err == ErrInterrupted { - return intr - } - return err -} - -// SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel -// include/linux/errno.h. These errnos are never returned to userspace -// directly, but are used to communicate the expected behavior of an -// interrupted syscall from the syscall to signal handling. -type SyscallRestartErrno int - -// These numeric values are significant because ptrace syscall exit tracing can -// observe them. -// -// For all of the following errnos, if the syscall is not interrupted by a -// signal delivered to a user handler, the syscall is restarted. -const ( - // ERESTARTSYS is returned by an interrupted syscall to indicate that it - // should be converted to EINTR if interrupted by a signal delivered to a - // user handler without SA_RESTART set, and restarted otherwise. - ERESTARTSYS = SyscallRestartErrno(512) - - // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it - // should always be restarted. - ERESTARTNOINTR = SyscallRestartErrno(513) - - // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it - // should be converted to EINTR if interrupted by a signal delivered to a - // user handler, and restarted otherwise. - ERESTARTNOHAND = SyscallRestartErrno(514) - - // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate - // that it should be restarted using a custom function. The interrupted - // syscall must register a custom restart function by calling - // Task.SetRestartSyscallFn. - ERESTART_RESTARTBLOCK = SyscallRestartErrno(516) -) - -// Error implements error.Error. -func (e SyscallRestartErrno) Error() string { - // Descriptions are borrowed from strace. - switch e { - case ERESTARTSYS: - return "to be restarted if SA_RESTART is set" - case ERESTARTNOINTR: - return "to be restarted" - case ERESTARTNOHAND: - return "to be restarted if no handler" - case ERESTART_RESTARTBLOCK: - return "interrupted by signal" - default: - return "(unknown interrupt error)" - } -} - -// SyscallRestartErrnoFromReturn returns the SyscallRestartErrno represented by -// rv, the value in a syscall return register. -func SyscallRestartErrnoFromReturn(rv uintptr) (SyscallRestartErrno, bool) { - switch int(rv) { - case -int(ERESTARTSYS): - return ERESTARTSYS, true - case -int(ERESTARTNOINTR): - return ERESTARTNOINTR, true - case -int(ERESTARTNOHAND): - return ERESTARTNOHAND, true - case -int(ERESTART_RESTARTBLOCK): - return ERESTART_RESTARTBLOCK, true - default: - return 0, false - } -} - -func init() { - AddErrorTranslation(ErrWouldBlock, unix.EWOULDBLOCK) - AddErrorTranslation(ErrInterrupted, unix.EINTR) - AddErrorTranslation(ErrExceedsFileSizeLimit, unix.EFBIG) -} diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index f00cfd0f5..b98de54c5 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -25,6 +25,7 @@ go_library( "stdclock.go", "stdclock_state.go", "tcpip.go", + "tcpip_state.go", "timer.go", ], visibility = ["//visibility:public"], @@ -69,7 +70,6 @@ deps_test( "//pkg/tcpip/header", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/packetsocket", "//pkg/tcpip/link/qdisc/fifo", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 48b24692b..c8460e63c 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) { addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { @@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } done := make(chan struct{}) @@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue @@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue @@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) { ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip1.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err) + } ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip2.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err) + } done := make(chan struct{}) fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { @@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) { addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { @@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) { ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip1.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err) + } ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip2.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err) + } c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) if err != nil { @@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) if err != nil { @@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ip.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %w", NICID, protocolAddr, err) + } l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) + return nil, nil, nil, fmt.Errorf("NewListener: %w", err) } c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) if err != nil { l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) + return nil, nil, nil, fmt.Errorf("DialTCP: %w", err) } c2, err = l.Accept() if err != nil { l.Close() c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) + return nil, nil, nil, fmt.Errorf("l.Accept: %w", err) } stop = func() { @@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { if err := l.Close(); err != nil { stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) + return nil, nil, nil, fmt.Errorf("l.Close(): %w", err) } return c1, c2, stop, nil @@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } ctx := context.Background() ctx, cancel := context.WithCancel(ctx) @@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) { }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.Addr.WithPrefix(), + } + if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err) + } fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { time.Sleep(time.Second) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index e0dfe5813..24c2c3e6b 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -324,6 +324,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field +// in ControlMessages. +func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPv6PacketInfo { + t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo) + } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" { + t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff) + } + } +} + // ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress // field in ControlMessages. func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { @@ -729,7 +742,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp return } l := int(opts[i+1]) - if i < 2 || i+l > limit { + if l < 2 || i+l > limit { return } i += l diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index 95ade0e5c..1f18213e5 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -49,9 +49,9 @@ const ( // EthernetAddressSize is the size, in bytes, of an ethernet address. EthernetAddressSize = 6 - // unspecifiedEthernetAddress is the unspecified ethernet address + // UnspecifiedEthernetAddress is the unspecified ethernet address // (all bits set to 0). - unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") // EthernetBroadcastAddress is an ethernet address that addresses every node // on a local link. @@ -134,7 +134,7 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { return false } - if addr == unspecifiedEthernetAddress { + if addr == UnspecifiedEthernetAddress { return false } diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index bf9ccbf1a..adc04e855 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -44,7 +44,7 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) { }, { "Unspecified", - unspecifiedEthernetAddress, + UnspecifiedEthernetAddress, false, }, { @@ -91,7 +91,7 @@ func TestIsMulticastEthernetAddress(t *testing.T) { }, { "Unspecified", - unspecifiedEthernetAddress, + UnspecifiedEthernetAddress, false, }, { diff --git a/pkg/tcpip/internal/tcp/BUILD b/pkg/tcpip/internal/tcp/BUILD new file mode 100644 index 000000000..9ae258a0b --- /dev/null +++ b/pkg/tcpip/internal/tcp/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "tcp", + srcs = ["tcp.go"], + visibility = ["//pkg/tcpip:__subpackages__"], + deps = [ + "//pkg/tcpip", + ], +) diff --git a/pkg/tcpip/internal/tcp/tcp.go b/pkg/tcpip/internal/tcp/tcp.go new file mode 100644 index 000000000..0616d368c --- /dev/null +++ b/pkg/tcpip/internal/tcp/tcp.go @@ -0,0 +1,48 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tcp contains internal type definitions that are not expected to be +// used by anyone else outside pkg/tcpip. +package tcp + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// TSOffset is an offset applied to the value of the TSVal field in the TCP +// Timestamp option. +// +// +stateify savable +type TSOffset struct { + milliseconds uint32 +} + +// NewTSOffset creates a new TSOffset from milliseconds. +func NewTSOffset(milliseconds uint32) TSOffset { + return TSOffset{ + milliseconds: milliseconds, + } +} + +// TSVal applies the offset to now and returns the timestamp in milliseconds. +func (offset TSOffset) TSVal(now tcpip.MonotonicTime) uint32 { + return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + offset.milliseconds +} + +// Elapsed calculates the elapsed time given now and the echoed back timestamp. +func (offset TSOffset) Elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration { + return time.Duration(offset.TSVal(now)-tsEcr) * time.Millisecond +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index f26c857eb..658557d62 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -28,7 +28,9 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Pkt *stack.PacketBuffer + Pkt *stack.PacketBuffer + + // TODO(https://gvisor.dev/issue/6537): Remove these fields. Proto tcpip.NetworkProtocolNumber Route stack.RouteInfo } @@ -244,7 +246,10 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol Route: r, } - e.q.Write(p) + // Write returns false if the queue is full. A full queue is not an error + // from the perspective of a LinkEndpoint so we ignore Write's return + // value and always return nil from this method. + _ = e.q.Write(p) return nil } @@ -290,3 +295,18 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { // AddHeader implements stack.LinkEndpoint.AddHeader. func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + p := PacketInfo{ + Pkt: pkt, + Proto: pkt.NetworkProtocolNumber, + } + + // Write returns false if the queue is full. A full queue is not an error + // from the perspective of a LinkEndpoint so we ignore Write's return + // value and always return nil from this method. + _ = e.q.Write(p) + + return nil +} diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index b427c6170..8211a2031 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -42,6 +42,14 @@ type Endpoint struct { nested.Endpoint } +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + if l := e.Endpoint.LinkAddress(); len(l) != 0 { + return l + } + return header.UnspecifiedEthernetAddress +} + // DeliverNetworkPacket implements stack.NetworkDispatcher. func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) @@ -57,18 +65,22 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP // Capabilities implements stack.LinkEndpoint. func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities() + c := e.Endpoint.Capabilities() + if c&stack.CapabilityLoopback == 0 { + c |= stack.CapabilityResolutionRequired + } + return c } // WritePacket implements stack.LinkEndpoint. func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + e.AddHeader(e.LinkAddress(), r.RemoteLinkAddress, proto, pkt) return e.Endpoint.WritePacket(r, proto, pkt) } // WritePackets implements stack.LinkEndpoint. func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - linkAddr := e.Endpoint.LinkAddress() + linkAddr := e.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) @@ -83,7 +95,10 @@ func (e *Endpoint) MaxHeaderLength() uint16 { } // ARPHardwareType implements stack.LinkEndpoint. -func (*Endpoint) ARPHardwareType() header.ARPHardwareType { +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + if a := e.Endpoint.ARPHardwareType(); a != header.ARPHardwareNone { + return a + } return header.ARPHardwareEther } @@ -97,3 +112,8 @@ func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkP } eth.Encode(&fields) } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + return e.Endpoint.WriteRawPacket(pkt) +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 48356c343..058242f96 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -505,6 +505,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net } } +// WriteRawPacket implements stack.LinkEndpoint. +func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } + // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 7012d8829..ca1f9c08d 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,19 +76,8 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - // Construct data as the unparsed portion for the loopback packet. - data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - - // Because we're immediately turning around and writing the packet back - // to the rx path, we intentionally don't preserve the remote and local - // link addresses from the stack.Route we're passed. - newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: data, - }) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt) - - return nil +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + return e.WriteRawPacket(pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -103,3 +92,19 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + // Construct data as the unparsed portion for the loopback packet. + data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + + // Because we're immediately turning around and writing the packet back + // to the rx path, we intentionally don't preserve the remote and local + // link addresses from the stack.Route we're passed. + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, + }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, pkt.NetworkProtocolNumber, newPkt) + + return nil +} diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 3e2a1aa94..844f5959b 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -131,6 +131,11 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { } +// WriteRawPacket implements stack.LinkEndpoint. +func (*InjectableEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 3e816b0c7..83a6c1cc8 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -60,16 +60,6 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco } } -// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. -func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.mu.RLock() - d := e.dispatcher - e.mu.RUnlock() - if d != nil { - d.DeliverOutboundPacket(remote, local, protocol, pkt) - } -} - // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.mu.Lock() @@ -152,3 +142,8 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.child.AddHeader(local, remote, protocol, pkt) } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + return e.child.WriteRawPacket(pkt) +} diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD deleted file mode 100644 index 6fff160ce..000000000 --- a/pkg/tcpip/link/packetsocket/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "packetsocket", - srcs = ["endpoint.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go deleted file mode 100644 index e01837e2d..000000000 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package packetsocket provides a link layer endpoint that provides the ability -// to loop outbound packets to any AF_PACKET sockets that may be interested in -// the outgoing packet. -package packetsocket - -import ( - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/link/nested" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type endpoint struct { - nested.Endpoint -} - -// New creates a new packetsocket LinkEndpoint. -func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - e := &endpoint{} - e.Endpoint.Init(lower, e) - return e -} - -// WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) - return e.Endpoint.WritePacket(r, protocol, pkt) -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) - } - - return e.Endpoint.WritePackets(r, pkts, proto) -} diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 5030b6ba1..3ed0aa3fe 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -121,3 +121,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { // AddHeader implements stack.LinkEndpoint. func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { } + +// WriteRawPacket implements stack.LinkEndpoint. +func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 40bd5560b..b41e3e2fa 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -108,11 +108,6 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } -// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. -func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) -} - // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { // nil means the NIC is being removed. @@ -228,3 +223,8 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.lower.AddHeader(local, remote, protocol, pkt) } + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + return e.lower.WriteRawPacket(pkt) +} diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s index 298bad55d..f2c230720 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s +++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s @@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 MOVQ $0x0, R10 // sigmask parameter which isn't used here MOVQ $0x10f, AX // SYS_PPOLL SYSCALL - CMPQ AX, $0xfffffffffffff001 + CMPQ AX, $0xfffffffffffff002 JLS ok MOVQ $-1, n+24(FP) NEGQ AX diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s index b62888b93..8807586c7 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s +++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s @@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 MOVD $0x0, R3 // sigmask parameter which isn't used here MOVD $0x49, R8 // SYS_PPOLL SVC - CMP $0xfffffffffffff001, R0 + CMP $0xfffffffffffff002, R0 BLS ok MOVD $-1, R1 MOVD R1, n+24(FP) diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index e76fc55b6..e53789d92 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -152,10 +152,22 @@ type PollEvent struct { // no data is available, it will block in a poll() syscall until the file // descriptor becomes readable. func BlockingRead(fd int, b []byte) (int, tcpip.Error) { + n, err := BlockingReadUntranslated(fd, b) + if err != 0 { + return n, TranslateErrno(err) + } + return n, nil +} + +// BlockingReadUntranslated reads from a file descriptor that is set up as +// non-blocking. If no data is available, it will block in a poll() syscall +// until the file descriptor becomes readable. It returns the raw unix.Errno +// value returned by the underlying syscalls. +func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) { for { n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) if e == 0 { - return int(n), nil + return int(n), 0 } event := PollEvent{ @@ -165,7 +177,7 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) { _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != unix.EINTR { - return 0, TranslateErrno(e) + return 0, e } } } @@ -181,7 +193,9 @@ func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip if e == 0 { return int(n), nil } - + if e != 0 && e != unix.EWOULDBLOCK { + return 0, TranslateErrno(e) + } stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) if stopped { return -1, nil @@ -204,6 +218,10 @@ func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpi return int(n), nil } + if e != 0 && e != unix.EWOULDBLOCK { + return 0, TranslateErrno(e) + } + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) if stopped { return -1, nil @@ -228,5 +246,13 @@ func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) }, } _, errno := BlockingPoll(&pevents[0], len(pevents), nil) + if errno != 0 { + return pevents[0].Revents&unix.POLLIN != 0, errno + } + + if pevents[1].Revents&unix.POLLHUP != 0 || pevents[1].Revents&unix.POLLERR != 0 { + errno = unix.ECONNRESET + } + return pevents[0].Revents&unix.POLLIN != 0, errno } diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 4215ee852..f8076d83c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -5,19 +5,26 @@ package(licenses = ["notice"]) go_library( name = "sharedmem", srcs = [ + "queuepair.go", "rx.go", + "server_rx.go", + "server_tx.go", "sharedmem.go", + "sharedmem_server.go", "sharedmem_unsafe.go", "tx.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/cleanup", + "//pkg/eventfd", "//pkg/log", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", + "//pkg/tcpip/link/sharedmem/pipe", "//pkg/tcpip/link/sharedmem/queue", "//pkg/tcpip/stack", "@org_golang_x_sys//unix:go_default_library", @@ -26,9 +33,7 @@ go_library( go_test( name = "sharedmem_test", - srcs = [ - "sharedmem_test.go", - ], + srcs = ["sharedmem_test.go"], library = ":sharedmem", deps = [ "//pkg/sync", @@ -41,3 +46,22 @@ go_test( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "sharedmem_server_test", + size = "small", + srcs = ["sharedmem_server_test.go"], + deps = [ + ":sharedmem", + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/header", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go index 696e6c9e5..a78826ebc 100644 --- a/pkg/tcpip/link/sharedmem/queue/rx.go +++ b/pkg/tcpip/link/sharedmem/queue/rx.go @@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool { } r.tx.Flush() - return true } @@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool { func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) { for { outBufs := bufs - // Pull the next descriptor from the rx pipe. b := r.rx.Pull() if b == nil { diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go new file mode 100644 index 000000000..b12647fdd --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queuepair.go @@ -0,0 +1,199 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "fmt" + "io/ioutil" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" +) + +const ( + // defaultQueueDataSize is the size of the shared memory data region that + // holds the scatter/gather buffers. + defaultQueueDataSize = 1 << 20 // 1MiB + + // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors. + // + // Assuming each packet data is approximately 1280 bytes (IPv6 Minimum MTU) + // then we can hold approximately 1024*1024/1280 ~ 819 packets in the data + // area. Which means the pipe needs to be big enough to hold 819 + // descriptors. + // + // Each descriptor is approximately 8 (slot descriptor in pipe) + + // 16 (packet descriptor) + 12 (for buffer descriptor) assuming each packet is + // stored in exactly 1 buffer descriptor (see queue/tx.go and pipe/tx.go.) + // + // Which means we need approximately 36*819 ~ 29 KiB to store all packet + // descriptors. We could go with a 32 KiB pipe but to give it some slack in + // how the upper layer may make use of the scatter gather buffers we double + // this to hold enough descriptors. + defaultQueuePipeSize = 64 << 10 // 64KiB + + // defaultSharedDataSize is the size of the sharedData region used to + // enable/disable notifications. + defaultSharedDataSize = 4 << 10 // 4KiB +) + +// A QueuePair represents a pair of TX/RX queues. +type QueuePair struct { + // txCfg is the QueueConfig to be used for transmit queue. + txCfg QueueConfig + + // rxCfg is the QueueConfig to be used for receive queue. + rxCfg QueueConfig +} + +// NewQueuePair creates a shared memory QueuePair. +func NewQueuePair() (*QueuePair, error) { + txCfg, err := createQueueFDs(queueSizes{ + dataSize: defaultQueueDataSize, + txPipeSize: defaultQueuePipeSize, + rxPipeSize: defaultQueuePipeSize, + sharedDataSize: defaultSharedDataSize, + }) + + if err != nil { + return nil, fmt.Errorf("failed to create tx queue: %s", err) + } + + rxCfg, err := createQueueFDs(queueSizes{ + dataSize: defaultQueueDataSize, + txPipeSize: defaultQueuePipeSize, + rxPipeSize: defaultQueuePipeSize, + sharedDataSize: defaultSharedDataSize, + }) + + if err != nil { + closeFDs(txCfg) + return nil, fmt.Errorf("failed to create rx queue: %s", err) + } + + return &QueuePair{ + txCfg: txCfg, + rxCfg: rxCfg, + }, nil +} + +// Close closes underlying tx/rx queue fds. +func (q *QueuePair) Close() { + closeFDs(q.txCfg) + closeFDs(q.rxCfg) +} + +// TXQueueConfig returns the QueueConfig for the receive queue. +func (q *QueuePair) TXQueueConfig() QueueConfig { + return q.txCfg +} + +// RXQueueConfig returns the QueueConfig for the transmit queue. +func (q *QueuePair) RXQueueConfig() QueueConfig { + return q.rxCfg +} + +type queueSizes struct { + dataSize int64 + txPipeSize int64 + rxPipeSize int64 + sharedDataSize int64 +} + +func createQueueFDs(s queueSizes) (QueueConfig, error) { + success := false + var eventFD eventfd.Eventfd + var dataFD, txPipeFD, rxPipeFD, sharedDataFD int + defer func() { + if success { + return + } + closeFDs(QueueConfig{ + EventFD: eventFD, + DataFD: dataFD, + TxPipeFD: txPipeFD, + RxPipeFD: rxPipeFD, + SharedDataFD: sharedDataFD, + }) + }() + eventFD, err := eventfd.Create() + if err != nil { + return QueueConfig{}, fmt.Errorf("eventfd failed: %v", err) + } + dataFD, err = createFile(s.dataSize, false) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err) + } + txPipeFD, err = createFile(s.txPipeSize, true) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err) + } + rxPipeFD, err = createFile(s.rxPipeSize, true) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err) + } + sharedDataFD, err = createFile(s.sharedDataSize, false) + if err != nil { + return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err) + } + success = true + return QueueConfig{ + EventFD: eventFD, + DataFD: dataFD, + TxPipeFD: txPipeFD, + RxPipeFD: rxPipeFD, + SharedDataFD: sharedDataFD, + }, nil +} + +func createFile(size int64, initQueue bool) (fd int, err error) { + const tmpDir = "/dev/shm/" + f, err := ioutil.TempFile(tmpDir, "sharedmem_test") + if err != nil { + return -1, fmt.Errorf("TempFile failed: %v", err) + } + defer f.Close() + unix.Unlink(f.Name()) + + if initQueue { + // Write the "slot-free" flag in the initial queue. + if _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0); err != nil { + return -1, fmt.Errorf("WriteAt failed: %v", err) + } + } + + fd, err = unix.Dup(int(f.Fd())) + if err != nil { + return -1, fmt.Errorf("unix.Dup(%d) failed: %v", f.Fd(), err) + } + + if err := unix.Ftruncate(fd, size); err != nil { + unix.Close(fd) + return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, err) + } + + return fd, nil +} + +func closeFDs(c QueueConfig) { + unix.Close(c.DataFD) + c.EventFD.Close() + unix.Close(c.TxPipeFD) + unix.Close(c.RxPipeFD) + unix.Close(c.SharedDataFD) +} diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go index e882a128c..87747dcc7 100644 --- a/pkg/tcpip/link/sharedmem/rx.go +++ b/pkg/tcpip/link/sharedmem/rx.go @@ -21,7 +21,7 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -30,7 +30,7 @@ type rx struct { data []byte sharedData []byte q queue.Rx - eventFD int + eventFD eventfd.Eventfd } // init initializes all state needed by the rx queue based on the information @@ -68,7 +68,7 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error { // Duplicate the eventFD so that caller can close it but we can still // use it. - efd, err := unix.Dup(c.EventFD) + efd, err := c.EventFD.Dup() if err != nil { unix.Munmap(txPipe) unix.Munmap(rxPipe) @@ -77,16 +77,6 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error { return err } - // Set the eventfd as non-blocking. - if err := unix.SetNonblock(efd, true); err != nil { - unix.Munmap(txPipe) - unix.Munmap(rxPipe) - unix.Munmap(data) - unix.Munmap(sharedData) - unix.Close(efd) - return err - } - // Initialize state based on buffers. r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData)) r.data = data @@ -105,7 +95,13 @@ func (r *rx) cleanup() { unix.Munmap(r.data) unix.Munmap(r.sharedData) - unix.Close(r.eventFD) + r.eventFD.Close() +} + +// notify writes to the tx.eventFD to indicate to the peer that there is data to +// be read. +func (r *rx) notify() { + r.eventFD.Notify() } // postAndReceive posts the provided buffers (if any), and then tries to read @@ -122,8 +118,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue. if len(b) != 0 && !r.q.PostBuffers(b) { r.q.EnableNotification() for !r.q.PostBuffers(b) { - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) + r.eventFD.Wait() if atomic.LoadUint32(stopRequested) != 0 { r.q.DisableNotification() return nil, 0 @@ -147,8 +142,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue. } // Wait for notification. - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) + r.eventFD.Wait() if atomic.LoadUint32(stopRequested) != 0 { r.q.DisableNotification() return nil, 0 diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go new file mode 100644 index 000000000..6ea21ffd1 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/server_rx.go @@ -0,0 +1,142 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/eventfd" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +type serverRx struct { + // packetPipe represents the receive end of the pipe that carries the packet + // descriptors sent by the client. + packetPipe pipe.Rx + + // completionPipe represents the transmit end of the pipe that will carry + // completion notifications from the server to the client. + completionPipe pipe.Tx + + // data represents the buffer area where the packet payload is held. + data []byte + + // eventFD is used to notify the peer when transmission is completed. + eventFD eventfd.Eventfd + + // sharedData the memory region to use to enable/disable notifications. + sharedData []byte +} + +// init initializes all state needed by the serverTx queue based on the +// information provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (s *serverRx) init(c *QueueConfig) error { + // Map in all buffers. + packetPipeMem, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + cu := cleanup.Make(func() { unix.Munmap(packetPipeMem) }) + defer cu.Clean() + + completionPipeMem, err := getBuffer(c.RxPipeFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(completionPipeMem) }) + + data, err := getBuffer(c.DataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(data) }) + + sharedData, err := getBuffer(c.SharedDataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(sharedData) }) + + // Duplicate the eventFD so that caller can close it but we can still + // use it. + efd, err := c.EventFD.Dup() + if err != nil { + return err + } + cu.Add(func() { efd.Close() }) + + s.packetPipe.Init(packetPipeMem) + s.completionPipe.Init(completionPipeMem) + s.data = data + s.eventFD = efd + s.sharedData = sharedData + + cu.Release() + return nil +} + +func (s *serverRx) cleanup() { + unix.Munmap(s.packetPipe.Bytes()) + unix.Munmap(s.completionPipe.Bytes()) + unix.Munmap(s.data) + unix.Munmap(s.sharedData) + s.eventFD.Close() +} + +// completionNotificationSize is size in bytes of a completion notification sent +// on the completion queue after a transmitted packet has been handled. +const completionNotificationSize = 8 + +// receive receives a single packet from the packetPipe. +func (s *serverRx) receive() []byte { + desc := s.packetPipe.Pull() + if desc == nil { + return nil + } + + pktInfo := queue.DecodeTxPacketHeader(desc) + contents := make([]byte, 0, pktInfo.Size) + toCopy := pktInfo.Size + for i := 0; i < pktInfo.BufferCount; i++ { + txBuf := queue.DecodeTxBufferHeader(desc, i) + if txBuf.Size <= toCopy { + contents = append(contents, s.data[txBuf.Offset:][:txBuf.Size]...) + toCopy -= txBuf.Size + continue + } + contents = append(contents, s.data[txBuf.Offset:][:toCopy]...) + break + } + + // Flush to let peer know that slots queued for transmission have been handled + // and its free to reuse the slots. + s.packetPipe.Flush() + // Encode packet completion. + b := s.completionPipe.Push(completionNotificationSize) + queue.EncodeTxCompletion(b, pktInfo.ID) + s.completionPipe.Flush() + return contents +} + +func (s *serverRx) waitForPackets() { + s.eventFD.Wait() +} diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go new file mode 100644 index 000000000..13a82903f --- /dev/null +++ b/pkg/tcpip/link/sharedmem/server_tx.go @@ -0,0 +1,175 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/eventfd" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +// serverTx represents the server end of the sharedmem queue and is used to send +// packets to the peer in the buffers posted by the peer in the fillPipe. +type serverTx struct { + // fillPipe represents the receive end of the pipe that carries the RxBuffers + // posted by the peer. + fillPipe pipe.Rx + + // completionPipe represents the transmit end of the pipe that carries the + // descriptors for filled RxBuffers. + completionPipe pipe.Tx + + // data represents the buffer area where the packet payload is held. + data []byte + + // eventFD is used to notify the peer when fill requests are fulfilled. + eventFD eventfd.Eventfd + + // sharedData the memory region to use to enable/disable notifications. + sharedData []byte +} + +// init initializes all tstate needed by the serverTx queue based on the +// information provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (s *serverTx) init(c *QueueConfig) error { + // Map in all buffers. + fillPipeMem, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + cu := cleanup.Make(func() { unix.Munmap(fillPipeMem) }) + defer cu.Clean() + + completionPipeMem, err := getBuffer(c.RxPipeFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(completionPipeMem) }) + + data, err := getBuffer(c.DataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(data) }) + + sharedData, err := getBuffer(c.SharedDataFD) + if err != nil { + return err + } + cu.Add(func() { unix.Munmap(sharedData) }) + + // Duplicate the eventFD so that caller can close it but we can still + // use it. + efd, err := c.EventFD.Dup() + if err != nil { + return err + } + cu.Add(func() { efd.Close() }) + + cu.Release() + + s.fillPipe.Init(fillPipeMem) + s.completionPipe.Init(completionPipeMem) + s.data = data + s.eventFD = efd + s.sharedData = sharedData + + return nil +} + +func (s *serverTx) cleanup() { + unix.Munmap(s.fillPipe.Bytes()) + unix.Munmap(s.completionPipe.Bytes()) + unix.Munmap(s.data) + unix.Munmap(s.sharedData) + s.eventFD.Close() +} + +// fillPacket copies the data in the provided views into buffers pulled from the +// fillPipe and returns a slice of RxBuffers that contain the copied data as +// well as the total number of bytes copied. +// +// To avoid allocations the filledBuffers are appended to the buffers slice +// which will be grown as required. +func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) { + filledBuffers = buffers[:0] + // fillBuffer copies as much of the views as possible into the provided buffer + // and returns any left over views (if any). + fillBuffer := func(buffer *queue.RxBuffer, views []buffer.View) (left []buffer.View) { + if len(views) == 0 { + return nil + } + availBytes := buffer.Size + copied := uint64(0) + for availBytes > 0 && len(views) > 0 { + n := copy(s.data[buffer.Offset+copied:][:uint64(buffer.Size)-copied], views[0]) + views[0].TrimFront(n) + if !views[0].IsEmpty() { + break + } + views = views[1:] + copied += uint64(n) + availBytes -= uint32(n) + } + buffer.Size = uint32(copied) + return views + } + + for len(views) > 0 { + var b []byte + // Spin till we get a free buffer reposted by the peer. + for { + if b = s.fillPipe.Pull(); b != nil { + break + } + } + rxBuffer := queue.DecodeRxBufferHeader(b) + // Copy the packet into the posted buffer. + views = fillBuffer(&rxBuffer, views) + totalCopied += rxBuffer.Size + filledBuffers = append(filledBuffers, rxBuffer) + } + + return filledBuffers, totalCopied +} + +func (s *serverTx) transmit(views []buffer.View) bool { + buffers := make([]queue.RxBuffer, 8) + buffers, totalCopied := s.fillPacket(views, buffers) + b := s.completionPipe.Push(queue.RxCompletionSize(len(buffers))) + if b == nil { + return false + } + queue.EncodeRxCompletion(b, totalCopied, 0 /* reserved */) + for i := 0; i < len(buffers); i++ { + queue.EncodeRxCompletionBuffer(b, i, buffers[i]) + } + s.completionPipe.Flush() + s.fillPipe.Flush() + return true +} + +func (s *serverTx) notify() { + s.eventFD.Notify() +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 30cf659b8..bcb37a465 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -24,14 +24,16 @@ package sharedmem import ( + "fmt" "sync/atomic" - "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,7 +49,7 @@ type QueueConfig struct { // EventFD is a file descriptor for the event that is signaled when // data is becomes available in this queue. - EventFD int + EventFD eventfd.Eventfd // TxPipeFD is a file descriptor for the tx pipe associated with the // queue. @@ -63,16 +65,97 @@ type QueueConfig struct { SharedDataFD int } +// FDs returns the FD's in the QueueConfig as a slice of ints. This must +// be used in conjunction with QueueConfigFromFDs to ensure the order +// of FDs matches when reconstructing the config when serialized or sent +// as part of control messages. +func (q *QueueConfig) FDs() []int { + return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD} +} + +// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each +// entry represents an file descriptor. The order of FDs in the slice must be in +// the order specified below for the config to be valid. QueueConfig.FDs() +// should be used when the config needs to be serialized or sent as part of a +// control message to ensure the correct order. +func QueueConfigFromFDs(fds []int) (QueueConfig, error) { + if len(fds) != 5 { + return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds)) + } + return QueueConfig{ + DataFD: fds[0], + EventFD: eventfd.Wrap(fds[1]), + TxPipeFD: fds[2], + RxPipeFD: fds[3], + SharedDataFD: fds[4], + }, nil +} + +// Options specify the details about the sharedmem endpoint to be created. +type Options struct { + // MTU is the mtu to use for this endpoint. + MTU uint32 + + // BufferSize is the size of each scatter/gather buffer that will hold packet + // data. + // + // NOTE: This directly determines number of packets that can be held in + // the ring buffer at any time. This does not have to be sized to the MTU as + // the shared memory queue design allows usage of more than one buffer to be + // used to make up a given packet. + BufferSize uint32 + + // LinkAddress is the link address for this endpoint (required). + LinkAddress tcpip.LinkAddress + + // TX is the transmit queue configuration for this shared memory endpoint. + TX QueueConfig + + // RX is the receive queue configuration for this shared memory endpoint. + RX QueueConfig + + // PeerFD is the fd for the connected peer which can be used to detect + // peer disconnects. + PeerFD int + + // OnClosed is a function that is called when the endpoint is being closed + // (probably due to peer going away) + OnClosed func(err tcpip.Error) + + // TXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityTXChecksumOffload. + TXChecksumOffload bool + + // RXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityRXChecksumOffload. + RXChecksumOffload bool +} + type endpoint struct { // mtu (maximum transmission unit) is the maximum size of a packet. + // mtu is immutable. mtu uint32 // bufferSize is the size of each individual buffer. + // bufferSize is immutable. bufferSize uint32 // addr is the local address of this endpoint. + // addr is immutable. addr tcpip.LinkAddress + // peerFD is an fd to the peer that can be used to detect when the + // peer is gone. + // peerFD is immutable. + peerFD int + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // hdrSize is the size of the link layer header if any. + // hdrSize is immutable. + hdrSize uint32 + // rx is the receive queue. rx rx @@ -83,34 +166,55 @@ type endpoint struct { // Wait group used to indicate that all workers have stopped. completed sync.WaitGroup + // onClosed is a function to be called when the FD's peer (if any) closes + // its end of the communication pipe. + onClosed func(tcpip.Error) + // mu protects the following fields. mu sync.Mutex // tx is the transmit queue. + // +checklocks:mu tx tx // workerStarted specifies whether the worker goroutine was started. + // +checklocks:mu workerStarted bool } // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { +func New(opts Options) (stack.LinkEndpoint, error) { e := &endpoint{ - mtu: mtu, - bufferSize: bufferSize, - addr: addr, + mtu: opts.MTU, + bufferSize: opts.BufferSize, + addr: opts.LinkAddress, + peerFD: opts.PeerFD, + onClosed: opts.OnClosed, } - if err := e.tx.init(bufferSize, &tx); err != nil { + if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil { return nil, err } - if err := e.rx.init(bufferSize, &rx); err != nil { + if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil { e.tx.cleanup() return nil, err } + e.caps = stack.LinkEndpointCapabilities(0) + if opts.RXChecksumOffload { + e.caps |= stack.CapabilityRXChecksumOffload + } + + if opts.TXChecksumOffload { + e.caps |= stack.CapabilityTXChecksumOffload + } + + if opts.LinkAddress != "" { + e.hdrSize = header.EthernetMinimumSize + e.caps |= stack.CapabilityResolutionRequired + } return e, nil } @@ -119,13 +223,13 @@ func (e *endpoint) Close() { // Tell dispatch goroutine to stop, then write to the eventfd so that // it wakes up in case it's sleeping. atomic.StoreUint32(&e.stopRequested, 1) - unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + e.rx.eventFD.Notify() // Cleanup the queues inline if the worker hasn't started yet; we also // know it won't start from now on because stopRequested is set to 1. e.mu.Lock() + defer e.mu.Unlock() workerPresent := e.workerStarted - e.mu.Unlock() if !workerPresent { e.tx.cleanup() @@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { e.workerStarted = true e.completed.Add(1) + + // Spin up a goroutine to monitor for peer shutdown. + if e.peerFD >= 0 { + e.completed.Add(1) + go func() { + defer e.completed.Done() + b := make([]byte, 1) + // When sharedmem endpoint is in use the peerFD is never used for any data + // transfer and this Read should only return if the peer is shutting down. + _, err := rawfile.BlockingRead(e.peerFD, b) + if e.onClosed != nil { + e.onClosed(err) + } + }() + } + // Link endpoints are not savable. When transportation endpoints // are saved, they stop sending outgoing packets and all // incoming packets are rejected. @@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool { // MTU implements stack.LinkEndpoint.MTU. It returns the value initialized // during construction. func (e *endpoint) MTU() uint32 { - return e.mtu - header.EthernetMinimumSize + return e.mtu - e.hdrSize } // Capabilities implements stack.LinkEndpoint.Capabilities. -func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps } // MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the // ethernet frame header size. -func (*endpoint) MaxHeaderLength() uint16 { - return header.EthernetMinimumSize +func (e *endpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) } // LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local @@ -202,17 +322,18 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net eth.Encode(ethHdr) } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) +// WriteRawPacket implements stack.LinkEndpoint. +func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } + +// +checklocks:e.mu +func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if e.addr != "" { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } views := pkt.Views() // Transmit the packet. - e.mu.Lock() ok := e.tx.transmit(views...) - e.mu.Unlock() - if !ok { return &tcpip.ErrWouldBlock{} } @@ -220,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol return nil } +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + if err := e.writePacketLocked(r, protocol, pkt); err != nil { + return err + } + e.tx.notify() + return nil +} + // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") +func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + n := 0 + var err tcpip.Error + e.mu.Lock() + defer e.mu.Unlock() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil { + break + } + n++ + } + // WritePackets never returns an error if it successfully transmitted at least + // one packet. + if err != nil && n == 0 { + return 0, err + } + e.tx.notify() + return n, nil } // dispatchLoop reads packets from the rx queue in a loop and dispatches them @@ -265,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { Data: buffer.View(b).ToVectorisedView(), }) - hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) - if !ok { - continue + var src, dst tcpip.LinkAddress + var proto tcpip.NetworkProtocolNumber + if e.addr != "" { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + continue + } + eth := header.Ethernet(hdr) + src = eth.SourceAddress() + dst = eth.DestinationAddress() + proto = eth.Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data().PullUp(1) + if !ok { + continue + } + switch header.IPVersion(h) { + case header.IPv4Version: + proto = header.IPv4ProtocolNumber + case header.IPv6Version: + proto = header.IPv6ProtocolNumber + default: + continue + } } - eth := header.Ethernet(hdr) // Send packet up the stack. - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) + d.DeliverNetworkPacket(src, dst, proto, pkt) } + e.mu.Lock() + defer e.mu.Unlock() + // Clean state. e.tx.cleanup() e.rx.cleanup() diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go new file mode 100644 index 000000000..ccc84989d --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go @@ -0,0 +1,333 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type serverEndpoint struct { + // mtu (maximum transmission unit) is the maximum size of a packet. + // mtu is immutable. + mtu uint32 + + // bufferSize is the size of each individual buffer. + // bufferSize is immutable. + bufferSize uint32 + + // addr is the local address of this endpoint. + // addr is immutable + addr tcpip.LinkAddress + + // rx is the receive queue. + rx serverRx + + // stopRequested is to be accessed atomically only, and determines if the + // worker goroutines should stop. + stopRequested uint32 + + // Wait group used to indicate that all workers have stopped. + completed sync.WaitGroup + + // peerFD is an fd to the peer that can be used to detect when the peer is + // gone. + // peerFD is immutable. + peerFD int + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // hdrSize is the size of the link layer header if any. + // hdrSize is immutable. + hdrSize uint32 + + // onClosed is a function to be called when the FD's peer (if any) closes its + // end of the communication pipe. + onClosed func(tcpip.Error) + + // mu protects the following fields. + mu sync.Mutex + + // tx is the transmit queue. + // +checklocks:mu + tx serverTx + + // workerStarted specifies whether the worker goroutine was started. + // +checklocks:mu + workerStarted bool +} + +// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be +// broken up into buffers of "bufferSize" bytes. +func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) { + e := &serverEndpoint{ + mtu: opts.MTU, + bufferSize: opts.BufferSize, + addr: opts.LinkAddress, + peerFD: opts.PeerFD, + onClosed: opts.OnClosed, + } + + if err := e.tx.init(&opts.RX); err != nil { + return nil, err + } + + if err := e.rx.init(&opts.TX); err != nil { + e.tx.cleanup() + return nil, err + } + + e.caps = stack.LinkEndpointCapabilities(0) + if opts.RXChecksumOffload { + e.caps |= stack.CapabilityRXChecksumOffload + } + + if opts.TXChecksumOffload { + e.caps |= stack.CapabilityTXChecksumOffload + } + + if opts.LinkAddress != "" { + e.hdrSize = header.EthernetMinimumSize + e.caps |= stack.CapabilityResolutionRequired + } + + return e, nil +} + +// Close frees all resources associated with the endpoint. +func (e *serverEndpoint) Close() { + // Tell dispatch goroutine to stop, then write to the eventfd so that it wakes + // up in case it's sleeping. + atomic.StoreUint32(&e.stopRequested, 1) + e.rx.eventFD.Notify() + + // Cleanup the queues inline if the worker hasn't started yet; we also know it + // won't start from now on because stopRequested is set to 1. + e.mu.Lock() + defer e.mu.Unlock() + workerPresent := e.workerStarted + + if !workerPresent { + e.tx.cleanup() + e.rx.cleanup() + } +} + +// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have +// stopped after a Close() call. +func (e *serverEndpoint) Wait() { + e.completed.Wait() +} + +// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that +// reads packets from the rx queue. +func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { + e.workerStarted = true + e.completed.Add(1) + if e.peerFD >= 0 { + e.completed.Add(1) + // Spin up a goroutine to monitor for peer shutdown. + go func() { + b := make([]byte, 1) + // When sharedmem endpoint is in use the peerFD is never used for any + // data transfer and this Read should only return if the peer is + // shutting down. + _, err := rawfile.BlockingRead(e.peerFD, b) + if e.onClosed != nil { + e.onClosed(err) + } + e.completed.Done() + }() + } + // Link endpoints are not savable. When transportation endpoints are saved, + // they stop sending outgoing packets and all incoming packets are rejected. + go e.dispatchLoop(dispatcher) // S/R-SAFE: see above. + } + e.mu.Unlock() +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *serverEndpoint) IsAttached() bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.workerStarted +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *serverEndpoint) MTU() uint32 { + return e.mtu - e.hdrSize +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the +// ethernet frame header size. +func (e *serverEndpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local +// link address. +func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress { + return e.addr +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + ethHdr := &header.EthernetFields{ + DstAddr: remote, + Type: protocol, + } + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) +} + +// WriteRawPacket implements stack.LinkEndpoint. +func (*serverEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +// +checklocks:e.mu +func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + + views := pkt.Views() + ok := e.tx.transmit(views) + if !ok { + return &tcpip.ErrWouldBlock{} + } + + return nil +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + // Transmit the packet. + e.mu.Lock() + defer e.mu.Unlock() + if err := e.writePacketLocked(r, protocol, pkt); err != nil { + return err + } + e.tx.notify() + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *serverEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + n := 0 + var err tcpip.Error + e.mu.Lock() + defer e.mu.Unlock() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil { + break + } + n++ + } + // WritePackets never returns an error if it successfully transmitted at least + // one packet. + if err != nil && n == 0 { + return 0, err + } + e.tx.notify() + return n, nil +} + +// dispatchLoop reads packets from the rx queue in a loop and dispatches them +// to the network stack. +func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) { + for atomic.LoadUint32(&e.stopRequested) == 0 { + b := e.rx.receive() + if b == nil { + e.rx.waitForPackets() + continue + } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + var src, dst tcpip.LinkAddress + var proto tcpip.NetworkProtocolNumber + if e.addr != "" { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + continue + } + eth := header.Ethernet(hdr) + src = eth.SourceAddress() + dst = eth.DestinationAddress() + proto = eth.Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data().PullUp(1) + if !ok { + continue + } + switch header.IPVersion(h) { + case header.IPv4Version: + proto = header.IPv4ProtocolNumber + case header.IPv6Version: + proto = header.IPv6ProtocolNumber + default: + continue + } + } + // Send packet up the stack. + d.DeliverNetworkPacket(src, dst, proto, pkt) + } + + e.mu.Lock() + defer e.mu.Unlock() + + // Clean state. + e.tx.cleanup() + e.rx.cleanup() + + e.completed.Done() +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go new file mode 100644 index 000000000..1bc58614e --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go @@ -0,0 +1,220 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package sharedmem_server_test + +import ( + "fmt" + "io" + "net" + "net/http" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +const ( + localLinkAddr = "\xde\xad\xbe\xef\x56\x78" + remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" + localIPv4Address = tcpip.Address("\x0a\x00\x00\x01") + remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02") + serverPort = 10001 + + defaultMTU = 1500 + defaultBufferSize = 1500 +) + +type stackOptions struct { + ep stack.LinkEndpoint + addr tcpip.Address +} + +func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) { + st := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + AllowExternalLoopbackTraffic: true, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + AllowExternalLoopbackTraffic: true, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, + }) + nicID := tcpip.NICID(1) + sniffEP := sniffer.New(stackOpts.ep) + opts := stack.NICOptions{Name: "eth0"} + if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil { + return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err) + } + + // Add Protocol Address. + protocolNum := ipv4.ProtocolNumber + routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}} + if len(stackOpts.addr) == 16 { + routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}} + protocolNum = ipv6.ProtocolNumber + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: protocolNum, + AddressWithPrefix: stackOpts.addr.WithPrefix(), + } + if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err) + } + + // Setup route table. + st.SetRouteTable(routeTable) + + return st, nil +} + +func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { + ep, err := sharedmem.New(sharedmem.Options{ + MTU: defaultMTU, + BufferSize: defaultBufferSize, + LinkAddress: localLinkAddr, + TX: qPair.TXQueueConfig(), + RX: qPair.RXQueueConfig(), + PeerFD: peerFD, + }) + if err != nil { + return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err) + } + st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address}) + if err != nil { + return nil, fmt.Errorf("failed to create client stack: %s", err) + } + return st, nil +} + +func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { + ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{ + MTU: defaultMTU, + BufferSize: defaultBufferSize, + LinkAddress: remoteLinkAddr, + TX: qPair.TXQueueConfig(), + RX: qPair.RXQueueConfig(), + PeerFD: peerFD, + }) + if err != nil { + return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err) + } + st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address}) + if err != nil { + return nil, fmt.Errorf("failed to create client stack: %s", err) + } + return st, nil +} + +type testContext struct { + clientStk *stack.Stack + serverStk *stack.Stack + peerFDs [2]int +} + +func newTestContext(t *testing.T) *testContext { + peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0) + if err != nil { + t.Fatalf("failed to create peerFDs: %s", err) + } + q, err := sharedmem.NewQueuePair() + if err != nil { + t.Fatalf("failed to create sharedmem queue: %s", err) + } + clientStack, err := newClientStack(t, q, peerFDs[0]) + if err != nil { + q.Close() + unix.Close(peerFDs[0]) + unix.Close(peerFDs[1]) + t.Fatalf("failed to create client stack: %s", err) + } + serverStack, err := newServerStack(t, q, peerFDs[1]) + if err != nil { + q.Close() + unix.Close(peerFDs[0]) + unix.Close(peerFDs[1]) + clientStack.Close() + t.Fatalf("failed to create server stack: %s", err) + } + return &testContext{ + clientStk: clientStack, + serverStk: serverStack, + peerFDs: peerFDs, + } +} + +func (ctx *testContext) cleanup() { + unix.Close(ctx.peerFDs[0]) + unix.Close(ctx.peerFDs[1]) + ctx.clientStk.Close() + ctx.serverStk.Close() +} + +func TestServerRoundTrip(t *testing.T) { + ctx := newTestContext(t) + defer ctx.cleanup() + listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort} + l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber) + if err != nil { + t.Fatalf("failed to start TCP Listener: %s", err) + } + defer l.Close() + var responseString = "response" + go func() { + http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(responseString)) + })) + }() + + dialFunc := func(address, protocol string) (net.Conn, error) { + return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + Dial: dialFunc, + }, + } + serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort) + response, err := httpClient.Get(serverURL) + if err != nil { + t.Fatalf("httpClient.Get(\"/\") failed: %s", err) + } + if got, want := response.StatusCode, http.StatusOK; got != want { + t.Fatalf("unexpected status code got: %d, want: %d", got, want) + } + body, err := io.ReadAll(response.Body) + if err != nil { + t.Fatalf("io.ReadAll(response.Body) failed: %s", err) + } + response.Body.Close() + if got, want := string(body), responseString; got != want { + t.Fatalf("unexpected response got: %s, want: %s", got, want) + } +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index d6d953085..66ffc33b8 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -19,9 +19,7 @@ package sharedmem import ( "bytes" - "io/ioutil" "math/rand" - "os" "strings" "testing" "time" @@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress t: t, packetCh: make(chan struct{}, 1000000), } - c.txCfg = createQueueFDs(t, queueSizes{ + c.txCfg, err = createQueueFDs(queueSizes{ dataSize: queueDataSize, txPipeSize: queuePipeSize, rxPipeSize: queuePipeSize, sharedDataSize: 4096, }) - - c.rxCfg = createQueueFDs(t, queueSizes{ + if err != nil { + t.Fatalf("createQueueFDs for tx failed: %s", err) + } + c.rxCfg, err = createQueueFDs(queueSizes{ dataSize: queueDataSize, txPipeSize: queuePipeSize, rxPipeSize: queuePipeSize, sharedDataSize: 4096, }) + if err != nil { + t.Fatalf("createQueueFDs for rx failed: %s", err) + } initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(Options{ + MTU: mtu, + BufferSize: bufferSize, + LinkAddress: addr, + TX: c.txCfg, + RX: c.rxCfg, + PeerFD: -1, + }) if err != nil { t.Fatalf("New failed: %v", err) } @@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip. func (c *testContext) cleanup() { c.ep.Close() - closeFDs(&c.txCfg) - closeFDs(&c.rxCfg) + closeFDs(c.txCfg) + closeFDs(c.rxCfg) c.txq.cleanup() c.rxq.cleanup() } @@ -191,69 +201,6 @@ func shuffle(b []int) { } } -func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir, ok := os.LookupEnv("TEST_TMPDIR") - if !ok { - tmpDir = os.Getenv("TMPDIR") - } - f, err := ioutil.TempFile(tmpDir, "sharedmem_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - unix.Unlink(f.Name()) - - if initQueue { - // Write the "slot-free" flag in the initial queue. - _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) - if err != nil { - t.Fatalf("WriteAt failed: %v", err) - } - } - - fd, err := unix.Dup(int(f.Fd())) - if err != nil { - t.Fatalf("Dup failed: %v", err) - } - - if err := unix.Ftruncate(fd, size); err != nil { - unix.Close(fd) - t.Fatalf("Ftruncate failed: %v", err) - } - - return fd -} - -func closeFDs(c *QueueConfig) { - unix.Close(c.DataFD) - unix.Close(c.EventFD) - unix.Close(c.TxPipeFD) - unix.Close(c.RxPipeFD) - unix.Close(c.SharedDataFD) -} - -type queueSizes struct { - dataSize int64 - txPipeSize int64 - rxPipeSize int64 - sharedDataSize int64 -} - -func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { - fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0) - if err != 0 { - t.Fatalf("eventfd failed: %v", error(err)) - } - - return QueueConfig{ - EventFD: int(fd), - DataFD: createFile(t, s.dataSize, false), - TxPipeFD: createFile(t, s.txPipeSize, true), - RxPipeFD: createFile(t, s.rxPipeSize, true), - SharedDataFD: createFile(t, s.sharedDataSize, false), - } -} - // TestSimpleSend sends 1000 packets with random header and payload sizes, // then checks that the right payload is received on the shared memory queues. func TestSimpleSend(t *testing.T) { @@ -672,7 +619,7 @@ func TestSimpleReceive(t *testing.T) { // Push completion. c.pushRxCompletion(uint32(len(contents)), bufs) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") @@ -718,7 +665,7 @@ func TestRxBuffersReposted(t *testing.T) { // Complete the buffer. c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for it to be reposted. bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) @@ -734,7 +681,7 @@ func TestRxBuffersReposted(t *testing.T) { // Complete with two buffers. c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for them to be reposted. for j := 0; j < 2; j++ { @@ -759,7 +706,7 @@ func TestReceivePostingIsFull(t *testing.T) { first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that packet is received. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") @@ -768,7 +715,7 @@ func TestReceivePostingIsFull(t *testing.T) { second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that no packet is received yet, as the worker is blocked trying // to repost. @@ -781,7 +728,7 @@ func TestReceivePostingIsFull(t *testing.T) { // Flush tx queue, which will allow the first buffer to be reposted, // and the second completion to be pulled. c.rxq.tx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Check that second packet completes. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") @@ -803,7 +750,7 @@ func TestCloseWhileWaitingToPost(t *testing.T) { bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + c.rxCfg.EventFD.Notify() // Wait for packet to be indicated. c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index e3210051f..35e5bff12 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -28,10 +29,12 @@ const ( // tx holds all state associated with a tx queue. type tx struct { - data []byte - q queue.Tx - ids idManager - bufs bufferManager + data []byte + q queue.Tx + ids idManager + bufs bufferManager + eventFD eventfd.Eventfd + sharedDataFD int } // init initializes all state needed by the tx queue based on the information @@ -64,7 +67,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error { t.ids.init() t.bufs.init(0, len(data), int(mtu)) t.data = data - + t.eventFD = c.EventFD + t.sharedDataFD = c.SharedDataFD return nil } @@ -142,6 +146,12 @@ func (t *tx) transmit(bufs ...buffer.View) bool { return true } +// notify writes to the tx.eventFD to indicate to the peer that there is data to +// be read. +func (t *tx) notify() { + t.eventFD.Notify() +} + // getBuffer returns a memory region mapped to the full contents of the given // file descriptor. func getBuffer(fd int) ([]byte, error) { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 28a172e71..2afa95af0 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -140,11 +140,6 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } -// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. -func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) -} - func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 4758a99ad..c3e4c3455 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/refs", "//pkg/refsvfs2", "//pkg/sync", - "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index d23210503..fa2131c28 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -174,7 +173,7 @@ func (d *Device) Write(data []byte) (int64, error) { return 0, linuxerr.EBADFD } if !endpoint.IsAttached() { - return 0, syserror.EIO + return 0, linuxerr.EIO } dataLen := int64(len(data)) @@ -249,7 +248,7 @@ func (d *Device) Read() ([]byte, error) { for { info, ok := endpoint.Read() if !ok { - return nil, syserror.ErrWouldBlock + return nil, linuxerr.ErrWouldBlock } v, ok := d.encodePkt(&info) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index a95602aa5..116e4defb 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -59,15 +59,6 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatchGate.Leave() } -// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. -func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if !e.dispatchGate.Enter() { - return - } - e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) - e.dispatchGate.Leave() -} - // Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and // registers with the lower endpoint as its dispatcher so that "e" is called // for inbound packets. @@ -155,3 +146,6 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.lower.AddHeader(local, remote, protocol, pkt) } + +// WriteRawPacket implements stack.LinkEndpoint. +func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index a71400ee9..b0e4237bd 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -80,6 +80,11 @@ func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBuffe return pkts.Len(), nil } +// WriteRawPacket implements stack.LinkEndpoint. +func (*countedEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { panic("unimplemented") diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 7b1ff44f4..c0179104a 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -23,8 +23,10 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 6515c31e5..e08243547 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -272,7 +272,6 @@ type protocol struct { func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } func (p *protocol) MinimumPacketSize() int { return header.ARPSize } -func (p *protocol) DefaultPrefixLen() int { return 0 } func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 5fcbfeaa2..061cc35ae 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext t.Fatalf("CreateNIC failed: %s", err) } - if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %s", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: stackAddr.WithPrefix(), + } + if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } tc.s.SetRouteTable([]tcpip.Route{{ @@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) { } if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: test.nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index 605e9ef8d..4d4d98caf 100644 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -101,6 +101,11 @@ func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return heade func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { } +// WriteRawPacket implements stack.LinkEndpoint. +func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // MakeRandPkt generates a randomized packet. transportHeaderLength indicates // how many random bytes will be copied in the Transport Header. // extraHeaderReserveLength indicates how much extra space will be reserved for diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 771b9173a..87f650661 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "bytes" "fmt" "strings" "testing" @@ -32,8 +33,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) const nicID = 1 @@ -230,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv4.ProtocolNumber, local) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: local.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, err + } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, @@ -246,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv6.ProtocolNumber, local) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: local.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + return nil, err + } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, @@ -269,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c } v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) + if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err) } v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) + if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err) } return s, e @@ -710,8 +725,8 @@ func TestReceive(t *testing.T) { if !ok { t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) } - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err) } else { ep.DecRef() } @@ -882,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -968,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -1234,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") } addr := localIPv6Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -1301,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name string protoFactory stack.NetworkProtocolFactory protoNum tcpip.NetworkProtocolNumber - nicAddr tcpip.Address + nicAddr tcpip.AddressWithPrefix remoteAddr tcpip.Address pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) @@ -1311,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) @@ -1352,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with IHL too small", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) @@ -1376,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 too small", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) @@ -1394,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 minimum size", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) @@ -1430,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with options", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) @@ -1475,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv4 with options and data across views", protoFactory: ipv4.NewProtocol, protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, + nicAddr: localIPv4AddrWithPrefix, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) @@ -1516,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) @@ -1556,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 with extension header", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) @@ -1601,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 minimum size", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) @@ -1636,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { name: "IPv6 too small", protoFactory: ipv6.NewProtocol, protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, + nicAddr: localIPv6AddrWithPrefix, remoteAddr: remoteIPv6Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) @@ -1660,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }{ { name: "unspecified source", - srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))), }, { name: "random source", - srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))), }, } @@ -1677,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.protoNum, + AddressWithPrefix: test.nicAddr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) - r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err) } defer r.Release() @@ -2032,3 +2051,97 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { }) } } + +func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.AddressWithPrefix + payloadOffset int + }{ + { + name: "IPv4", + proto: header.IPv4ProtocolNumber, + addr: localIPv4AddrWithPrefix, + payloadOffset: header.IPv4MinimumSize, + }, + { + name: "IPv6", + proto: header.IPv6ProtocolNumber, + addr: localIPv6AddrWithPrefix, + payloadOffset: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + RawFactory: raw.EndpointFactory{}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.proto, + AddressWithPrefix: test.addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.addr.Subnet(), + NIC: nicID, + }, + }) + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) + } + defer ep.Close() + + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.addr.Address, + }, + } + data := []byte{1, 2, 3, 4} + var r bytes.Reader + r.Reset(data) + if n, err := ep.Write(&r, writeOpts); err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } else if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) + } + + // Wait for the endpoint to become readable. + <-ch + + var w bytes.Buffer + rr, err := ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if err != nil { + t.Fatalf("ep.Read(...): %s", err) + } + if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" { + t.Errorf("payload mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" { + t.Errorf("remote addr mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 2aa38eb98..1c3b0887f 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -167,14 +167,17 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet p := hdr.TransportProtocol() dstAddr := hdr.DestinationAddress() // Skip the ip header, then deliver the error. - pkt.Data().DeleteFront(hlen) + if _, ok := pkt.Data().Consume(hlen); !ok { + panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen)) + } e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived // ICMP packets don't have their TransportHeader fields set. See - // icmp/protocol.go:protocol.Parse for a full explanation. + // icmp/protocol.go:protocol.Parse for a full explanation. Not all ICMP types + // require consuming the header, so we only call PullUp. v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { received.invalid.Increment() @@ -240,15 +243,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.echoRequest.Increment() - sent := e.stats.icmp.packetsSent - if !e.protocol.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return - } - // DeliverTransportPacket will take ownership of pkt so don't use it beyond // this point. Make a deep copy of the data before pkt gets sent as we will - // be modifying fields. + // be modifying fields. Both the ICMP header (with its type modified to + // EchoReply) and payload are reused in the reply packet. // // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no // waiting endpoints. Consider moving responsibility for doing the copy to @@ -281,6 +279,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } defer r.Release() + sent := e.stats.icmp.packetsSent + if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) { + sent.rateLimited.Increment() + return + } + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the // header information, we may have to change this code to handle the // ICMP header no longer being in the data buffer. @@ -331,6 +335,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.echoReply.Increment() + // ICMP sockets expect the ICMP header to be present, so we don't consume + // the ICMP header. e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: @@ -338,7 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { mtu := h.MTU() code := h.Code() - pkt.Data().DeleteFront(header.ICMPv4MinimumSize) + if _, ok := pkt.Data().Consume(header.ICMPv4MinimumSize); !ok { + panic("could not consume ICMPv4MinimumSize bytes") + } switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) @@ -562,13 +570,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. @@ -606,6 +607,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip } } + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) { + switch reason := reason.(type) { + case *icmpReasonPortUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonProtoUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetworkUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0 + case *icmpReasonFragmentationNeeded: + return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0 + case *icmpReasonTTLExceeded: + return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0 + case *icmpReasonParamProblem: + return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + }() + + if !p.allowICMPReply(icmpType, icmpCode) { + sent.rateLimited.Increment() + return nil + } + // Now work out how much of the triggering packet we should return. // As per RFC 1812 Section 4.3.2.3 // @@ -658,44 +688,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonProtoUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetworkUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4NetUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4HostUnreachable) - counter = sent.dstUnreachable - case *icmpReasonFragmentationNeeded: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) - counter = sent.dstUnreachable - case *icmpReasonTTLExceeded: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.timeExceeded - case *icmpReasonParamProblem: - icmpHdr.SetType(header.ICMPv4ParamProblem) - icmpHdr.SetCode(header.ICMPv4UnusedCode) - icmpHdr.SetPointer(reason.pointer) - counter = sent.paramProblem - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetCode(icmpCode) + icmpHdr.SetType(icmpType) + icmpHdr.SetPointer(pointer) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 4bd6f462e..c6576fcbc 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma // cycles. func TestIGMPV1Present(t *testing.T) { e, s, clock := createStack(t, true) - addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength} - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { @@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) { // The initial set of IGMP reports that were queued should be sent once an // address is assigned. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: stackAddr, + PrefixLen: defaultPrefixLength, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if got := reportStat.Value(); got != 1 { t.Errorf("got reportStat.Value() = %d, want = 1", got) @@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) { t.Run(test.name, func(t *testing.T) { e, s, _ := createStack(t, true) for _, address := range test.stackAddresses { - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: address, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } stats := s.Stats() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 44c85bdb8..9b71738ae 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -167,6 +167,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { return nil } +func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + ep, ok := p.mu.eps[id] + return ep, ok +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -240,7 +247,7 @@ func (e *endpoint) Enable() tcpip.Error { } // Create an endpoint to receive broadcast packets on this interface. - ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { return err } @@ -419,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -459,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -542,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -569,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -710,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -737,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -746,7 +753,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. - newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader())) + newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength())) + newHdr := header.IPv4(newPkt.NetworkHeader().View()) // As per RFC 791 page 30, Time to Live, // @@ -755,12 +763,19 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // Even if no local information is available on the time actually // spent, the field must be decremented by 1. newHdr.SetTTL(ttl - 1) + // We perform a full checksum as we may have updated options above. The IP + // header is relatively small so this is not expected to be an expensive + // operation. + newHdr.SetChecksum(0) + newHdr.SetChecksum(^newHdr.CalculateChecksum()) - switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(newHdr).ToVectorisedView(), - IsForwardedPacket: true, - })); err.(type) { + forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID()) + if !ok { + // The interface was removed after we obtained the route. + return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}} + } + + switch err := forwardToEp.writePacket(r, newPkt, true /* headerIncluded */); err.(type) { case nil: return nil case *tcpip.ErrMessageTooLong: @@ -826,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -856,6 +871,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum } func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { + pkt.NICID = e.nic.ID() + // Raw socket packets are delivered based solely on the transport protocol // number. We only require that the packet be valid IPv4, and that they not // be fragmented. @@ -863,7 +880,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } - pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() @@ -924,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -1074,11 +1090,11 @@ func (e *endpoint) Close() { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties) if err == nil { e.mu.igmp.sendQueuedReports() } @@ -1199,6 +1215,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv4Type]struct{} } // defaultTTL is the current default TTL for the protocol. Only the @@ -1225,11 +1244,6 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv4MinimumSize } -// DefaultPrefixLen returns the IPv4 default prefix length. -func (p *protocol) DefaultPrefixLen() int { - return header.IPv4AddressSize * 8 -} - // ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv4(v) @@ -1319,6 +1333,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type and code may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool { + // Mimic linux and never rate limit for PMTU discovery. + // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288 + if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded { + return true + } + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { @@ -1398,6 +1429,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) + // Set ICMP rate limiting to Linux defaults. + // See https://man7.org/linux/man-pages/man7/icmp.7.html. + p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{ + header.ICMPv4DstUnreachable: struct{}{}, + header.ICMPv4SrcQuench: struct{}{}, + header.ICMPv4TimeExceeded: struct{}{}, + header.ICMPv4ParamProblem: struct{}{}, + } return p } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 73407be67..ef91245d7 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) { defer ep.Close() // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if err := ep.Connect(randomAddr); err != nil { t.Errorf("Connect failed: %v", err) @@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err) } expectedEmittedPacketCount := 1 @@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) } // Default routes for IPv4 so ICMP can find a route to the remote @@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x02") tos = 0 ident = 1 ttl = 48 @@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } for _, f := range test.fragments { @@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x02") tos = 0 ident = 1 ttl = 48 @@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) { const ( nicID = 1 - addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 + addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1 + addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2 + addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3 ) // Build and return a UDP header containing payload. @@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } wq := waiter.Queue{} @@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { t.Fatalf("CreateNIC(1, _) failed: %s", err) } const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" + src = tcpip.Address("\x10\x00\x00\x01") + dst = tcpip.Address("\x10\x00\x00\x02") ) - if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { mask := tcpip.AddressMask(header.IPv4Broadcast) @@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) } s.SetRouteTable([]tcpip.Route{ @@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ @@ -3349,3 +3373,139 @@ func TestCloseLocking(t *testing.T) { } }() } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv4ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.UDPProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index f99cbf8f3..f814926a3 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -51,6 +51,7 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 94caaae6c..ff23d48e7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data().DeleteFront(header.IPv6MinimumSize) + if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok { + panic("could not consume IPv6MinimumSize bytes") + } if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) + if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok { + panic("could not consume IPv6FragmentHeaderSize bytes") + } } e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) @@ -325,7 +329,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.packetTooBig.Increment() - hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize) + hdr, ok := pkt.Data().Consume(header.ICMPv6PacketTooBigMinimumSize) if !ok { received.invalid.Increment() return @@ -334,18 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r if err != nil { networkMTU = 0 } - pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() - hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize) + hdr, ok := pkt.Data().Consume(header.ICMPv6DstUnreachableMinimumSize) if !ok { received.invalid.Increment() return } code := header.ICMPv6(hdr).Code() - pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) switch code { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) @@ -692,6 +694,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } defer r.Release() + if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) { + sent.rateLimited.Increment() + return + } + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data().ExtractVV(), @@ -1174,13 +1181,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. // Unfortunately at this time ICMP Packets do not have a transport @@ -1198,6 +1198,33 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip } } + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) { + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer + case *icmpReasonPortUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0 + case *icmpReasonPacketTooBig: + return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0 + case *icmpReasonHopLimitExceeded: + return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0 + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + }() + + if !p.allowICMPReply(icmpType) { + sent.rateLimited.Increment() + return nil + } + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() // As per RFC 4443 section 2.4 @@ -1232,40 +1259,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonParameterProblem: - icmpHdr.SetType(header.ICMPv6ParamProblem) - icmpHdr.SetCode(reason.code) - icmpHdr.SetTypeSpecific(reason.pointer) - counter = sent.paramProblem - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6AddressUnreachable) - counter = sent.dstUnreachable - case *icmpReasonPacketTooBig: - icmpHdr.SetType(header.ICMPv6PacketTooBig) - icmpHdr.SetCode(header.ICMPv6UnusedCode) - counter = sent.packetTooBig - case *icmpReasonHopLimitExceeded: - icmpHdr.SetType(header.ICMPv6TimeExceeded) - icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv6TimeExceeded) - icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) - counter = sent.timeExceeded - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetType(icmpType) + icmpHdr.SetCode(icmpCode) + icmpHdr.SetTypeSpecific(typeSpecific) + dataRange := newPkt.Data().AsRange() icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 7c2a3e56b..03d9f425c 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -225,8 +226,8 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") } addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } @@ -407,8 +408,12 @@ func newTestContext(t *testing.T) *testContext { if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } - if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) + llProtocolAddr0 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err) } c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) @@ -416,8 +421,12 @@ func newTestContext(t *testing.T) *testContext { if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) + llProtocolAddr1 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr1.WithPrefix(), + } + if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err) } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -690,8 +699,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -883,8 +896,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -1065,8 +1082,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -1240,8 +1261,12 @@ func TestLinkAddressRequest(t *testing.T) { } if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: test.nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -1411,12 +1436,14 @@ func TestPacketQueing(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, Clock: clock, }) + // Make sure ICMP rate limiting doesn't get in our way. + s.SetICMPLimit(rate.Inf) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) + if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err) } s.SetRouteTable([]tcpip.Route{ @@ -1669,8 +1696,12 @@ func TestCallsToNeighborCache(t *testing.T) { if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } { @@ -1704,8 +1735,8 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") } addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) } else { ep.DecRef() } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b1aec5312..600e805f8 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1024,7 +1024,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. - newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader())) + newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength())) + newHdr := header.IPv6(newPkt.NetworkHeader().View()) // As per RFC 8200 section 3, // @@ -1032,11 +1033,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) - switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(newHdr).ToVectorisedView(), - IsForwardedPacket: true, - })); err.(type) { + forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID()) + if !ok { + // The interface was removed after we obtained the route. + return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}} + } + + switch err := forwardToEp.writePacket(r, newPkt, newPkt.TransportProtocolNumber, true /* headerIncluded */); err.(type) { case nil: return nil case *tcpip.ErrMessageTooLong: @@ -1097,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -1127,11 +1130,12 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum } func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { + pkt.NICID = e.nic.ID() + // Raw socket packets are delivered based solely on the transport protocol // number. We only require that the packet be valid IPv6. e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) - pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() @@ -1179,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1533,19 +1537,22 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. - // Calculate the number of octets parsed from data. We want to remove all - // the data except the unparsed portion located at the end, which its size - // is extHdr.Buf.Size(). + // Calculate the number of octets parsed from data. We want to consume all + // the data except the unparsed portion located at the end, whose size is + // extHdr.Buf.Size(). trim := pkt.Data().Size() - extHdr.Buf.Size() // For unfragmented packets, extHdr still contains the transport header. - // Get rid of it. + // Consume that too. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. trim += pkt.TransportHeader().View().Size() - pkt.Data().DeleteFront(trim) + if _, ok := pkt.Data().Consume(trim); !ok { + stats.MalformedPacketsReceived.Increment() + return fmt.Errorf("could not consume %d bytes", trim) + } stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -1627,12 +1634,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. e.mu.Lock() defer e.mu.Unlock() - return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) + return e.addAndAcquirePermanentAddressLocked(addr, properties) } // addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but @@ -1642,8 +1649,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // solicited-node multicast group and start duplicate address detection. // // Precondition: e.mu must be write locked. -func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { - addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) { + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties) if err != nil { return nil, err } @@ -1986,6 +1993,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv6Type]struct{} } ids []uint32 @@ -1997,7 +2007,8 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - fragmentation *fragmentation.Fragmentation + fragmentation *fragmentation.Fragmentation + icmpRateLimiter *stack.ICMPRateLimiter } // Number returns the ipv6 protocol number. @@ -2010,11 +2021,6 @@ func (p *protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen returns the IPv6 default prefix length. -func (p *protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) @@ -2086,6 +2092,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { return nil } +func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + ep, ok := p.mu.eps[id] + return ep, ok +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -2171,6 +2184,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, @@ -2267,6 +2292,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) p.SetDefaultTTL(DefaultTTL) + // Set default ICMP rate limiting to Linux defaults. + // + // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big) + // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt. + defaultIcmpTypes := make(map[header.ICMPv6Type]struct{}) + for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ { + switch i { + case header.ICMPv6PacketTooBig: + // Do not rate limit packet too big by default. + default: + defaultIcmpTypes[i] = struct{}{} + } + } + p.mu.icmpRateLimitedTypes = defaultIcmpTypes + return p } } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index d2a23fd4f..e5286081e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -41,12 +41,12 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") // The least significant 3 bytes are the same as addr2 so both addr2 and // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" - addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02") + addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03") // Tests use the extension header identifier values as uint8 instead of // header.IPv6ExtensionHeaderIdentifier. @@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { // addr2/addr3 yet as we haven't added those addresses. test.rxf(t, s, e, addr1, snmc, 0) - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err) } // Should receive a packet destined to the solicited node address of // addr2/addr3 now that we have added added addr2. test.rxf(t, s, e, addr1, snmc, 1) - if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr3.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err) } // Should still receive a packet destined to the solicited node address of @@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: test.addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil { @@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Add a default route so that a return packet knows where to go. @@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } wq := waiter.Queue{} @@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") nicID = 1 hoplimit = 255 @@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, @@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) { func TestFragmentReassemblyTimeout(t *testing.T) { const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") nicID = 1 hoplimit = 255 @@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: addr2.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, @@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { t.Fatalf("CreateNIC(1, _) failed: %s", err) } const ( - src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") ) - if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: src.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") @@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err) } outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") @@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -3341,7 +3373,8 @@ func TestForwarding(t *testing.T) { ipHeaderLength := header.IPv6MinimumSize icmpHeaderLength := header.ICMPv6MinimumSize - totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen + payloadLength := icmpHeaderLength + test.payloadLength + extHdrLen + totalLength := ipHeaderLength + payloadLength hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) @@ -3359,7 +3392,7 @@ func TestForwarding(t *testing.T) { copy(hdr.Prepend(extHdrLen), extHdrBytes) ip := header.IPv6(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength), + PayloadLength: uint16(payloadLength), TransportProtocol: transportProtocol, HopLimit: test.TTL, SrcAddr: test.sourceAddr, @@ -3489,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) { t.Error(err) } } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::2").To16()), + PrefixLen: 64, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, + Src: host2IPv6Addr.AddressWithPrefix.Address, + Dst: host1IPv6Addr.AddressWithPrefix.Address, + })) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 1, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv6ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv6MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + + // Calculate the UDP checksum and set it. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(nil, sum) + udpH.SetChecksum(^udpH.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 1, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index bc9cf6999..3e5c438d3 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) { // Note, we will still expect to send a report for the global address's // solicited node address from the unspecified address as per RFC 3590 // section 4. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + globalProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: globalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err) } reportCounter++ if got := reportStat.Value(); got != reportCounter { @@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) { // Adding a link-local address should send a report for its solicited node // address and globalMulticastAddr. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + linkLocalProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err) } if dadResolutionTime != 0 { reportCounter++ @@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 8837d66d8..938427420 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) + addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{ + PEB: stack.FirstPrimaryEndpoint, + ConfigType: configType, + Deprecated: deprecated, + }) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index f0186c64e..8297a7e10 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) @@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: nicAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) @@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) @@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } ndpNASize := header.ICMPv6NeighborAdvertMinimumSize @@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) { checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}), )) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: lladdr0.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } checkDADMsg() diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 1b96b1fb8..26640b7ee 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addr := tcpip.AddressWithPrefix{ - Address: stackIPv4Addr, - PrefixLen: defaultIPv4PrefixLength, + addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: stackIPv4Addr, + PrefixLen: defaultIPv4PrefixLength, + }, + } + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(), } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, clock diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 009cab643..05b879543 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -146,8 +146,12 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Add default route. diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index c10b19aa0..a72afadda 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -124,13 +124,13 @@ func main() { log.Fatalf("Bad IP address: %v", addrName) } - var addr tcpip.Address + var addrWithPrefix tcpip.AddressWithPrefix var proto tcpip.NetworkProtocolNumber if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) + addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix() proto = ipv4.ProtocolNumber } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) + addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix() proto = ipv6.ProtocolNumber } else { log.Fatalf("Unknown IP type: %v", addrName) @@ -176,11 +176,15 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) + subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address)))) if err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 6bce3af04..b0b2d0afd 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -57,6 +57,11 @@ type SocketOptionsHandler interface { // OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE. OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) + + // WakeupWriters is invoked when the send buffer size for an endpoint is + // changed. The handler notifies the writers if the send buffer size is + // increased with setsockopt(2) for TCP endpoints. + WakeupWriters() } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -98,6 +103,9 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) { return v } +// WakeupWriters implements SocketOptionsHandler.WakeupWriters. +func (*DefaultSocketOptionsHandler) WakeupWriters() {} + // OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize. func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) { return v @@ -162,10 +170,14 @@ type SocketOptions struct { // message is passed with incoming packets. receiveTClassEnabled uint32 - // receivePacketInfoEnabled is used to specify if more inforamtion is - // provided with incoming packets such as interface index and address. + // receivePacketInfoEnabled is used to specify if more information is + // provided with incoming IPv4 packets. receivePacketInfoEnabled uint32 + // receivePacketInfoEnabled is used to specify if more information is + // provided with incoming IPv6 packets. + receiveIPv6PacketInfoEnabled uint32 + // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets // being written have an IP header and the endpoint should not attach an IP // header. @@ -352,6 +364,16 @@ func (so *SocketOptions) SetReceivePacketInfo(v bool) { storeAtomicBool(&so.receivePacketInfoEnabled, v) } +// GetIPv6ReceivePacketInfo gets value for IPV6_RECVPKTINFO option. +func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool { + return atomic.LoadUint32(&so.receiveIPv6PacketInfoEnabled) != 0 +} + +// SetIPv6ReceivePacketInfo sets value for IPV6_RECVPKTINFO option. +func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) { + storeAtomicBool(&so.receiveIPv6PacketInfoEnabled, v) +} + // GetHeaderIncluded gets value for IP_HDRINCL option. func (so *SocketOptions) GetHeaderIncluded() bool { return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0 @@ -626,6 +648,9 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize) } so.sendBufferSize.Store(sendBufferSize) + if notify { + so.handler.WakeupWriters() + } } // GetReceiveBufferSize gets value for SO_RCVBUF option. diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index e0847e58a..6c42ab29b 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -85,6 +85,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/internal/tcp", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/transport/tcpconntrack", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index ae0bb4ace..7e4b5bf74 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We now promote the address. for i, s := range a.mu.primary { if s == addrState { - switch peb { + switch properties.PEB { case CanBePrimaryEndpoint: // The address is already in the primary address list. attemptAddToPrimary = false @@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address case NeverPrimaryEndpoint: a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } break } @@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // Acquire the address before returning it. addrState.mu.refs++ - addrState.mu.deprecated = deprecated - addrState.mu.configType = configType + addrState.mu.deprecated = properties.Deprecated + addrState.mu.configType = properties.ConfigType if attemptAddToPrimary { - switch peb { + switch properties.PEB { case NeverPrimaryEndpoint: case CanBePrimaryEndpoint: a.mu.primary = append(a.mu.primary, addrState) @@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address a.mu.primary[0] = addrState } default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } } @@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() - ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */) if err != nil { // addAndAcquireAddressLocked only returns an error if the address is // already assigned but we just checked above if the address exists so we // expect no error. - panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err)) } // From https://golang.org/doc/faq#nil_error: diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 140f146f6..c55f85743 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { } { - ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { - t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err) } // We don't need the address endpoint. ep.DecRef() diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 068dab7ce..16d295271 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -64,13 +64,21 @@ type tuple struct { // tupleEntry is used to build an intrusive list of tuples. tupleEntry - tupleID - // conn is the connection tracking entry this tuple belongs to. conn *conn // direction is the direction of the tuple. direction direction + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + tupleID tupleID +} + +func (t *tuple) id() tupleID { + t.mu.RLock() + defer t.mu.RUnlock() + return t.tupleID } // tupleID uniquely identifies a connection in one direction. It currently @@ -103,50 +111,43 @@ func (ti tupleID) reply() tupleID { // // +stateify savable type conn struct { + ct *ConnTrack + // original is the tuple in original direction. It is immutable. original tuple - // reply is the tuple in reply direction. It is immutable. + // reply is the tuple in reply direction. reply tuple - // manip indicates if the packet should be manipulated. It is immutable. - // TODO(gvisor.dev/issue/5696): Support updating manipulation type. + mu sync.RWMutex `state:"nosave"` + // Indicates that the connection has been finalized and may handle replies. + // + // +checklocks:mu + finalized bool + // manip indicates if the packet should be manipulated. + // + // +checklocks:mu manip manipType - - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. It is immutable. - tcbHook Hook - - // mu protects all mutable state. - mu sync.Mutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states - // of tcp connection and is protected by mu. + // of tcp connection. + // + // +checklocks:mu tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and - // is updated by each packet on the connection. It is protected by mu. + // is updated by each packet on the connection. // // TODO(gvisor.dev/issue/5939): do not use the ambient clock. + // + // +checklocks:mu lastUsed time.Time `state:".(unixTime)"` } -// newConn creates new connection. -func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn -} - // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now time.Time) bool { const establishedTimeout = 5 * 24 * time.Hour const defaultTimeout = 120 * time.Second - cn.mu.Lock() - defer cn.mu.Unlock() + cn.mu.RLock() + defer cn.mu.RUnlock() if cn.tcb.State() == tcpconntrack.ResultAlive { // Use the same default as Linux, which doesn't delete // established connections for 5(!) days. @@ -159,17 +160,30 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // -// Precondition: cn.mu must be held. -func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. +// +checklocks:cn.mu +func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) { + if pkt.TransportProtocolNumber != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) - } else if hook == cn.tcbHook { + return + } + + switch dir { + case dirOriginal: cn.tcb.UpdateStateOutbound(tcpHeader) - } else { + case dirReply: cn.tcb.UpdateStateInbound(tcpHeader) + default: + panic(fmt.Sprintf("unhandled dir = %d", dir)) } } @@ -194,44 +208,34 @@ type ConnTrack struct { // It is immutable. seed uint32 + mu sync.RWMutex `state:"nosave"` // mu protects the buckets slice, but not buckets' contents. Only take // the write lock if you are modifying the slice or saving for S/R. - mu sync.RWMutex `state:"nosave"` - - // buckets is protected by mu. + // + // +checklocks:mu buckets []bucket } // +stateify savable type bucket struct { - // mu protects tuples. - mu sync.Mutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu tuples tupleList } -// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid -// TCP header. -// -// Preconditions: pkt.NetworkHeader() is valid. -func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { - netHeader := pkt.Network() - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, &tcpip.ErrUnknownProtocol{} - } - - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return tupleID{}, &tcpip.ErrUnknownProtocol{} +func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { + return tcpHeader, true + } + case header.UDPProtocolNumber: + if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { + return udpHeader, true + } } - return tupleID{ - srcAddr: netHeader.SourceAddress(), - srcPort: tcpHeader.SourcePort(), - dstAddr: netHeader.DestinationAddress(), - dstPort: tcpHeader.DestinationPort(), - transProto: netHeader.TransportProtocol(), - netProto: pkt.NetworkProtocolNumber, - }, nil + return nil, false } func (ct *ConnTrack) init() { @@ -240,167 +244,185 @@ func (ct *ConnTrack) init() { ct.buckets = make([]bucket, numBuckets) } -// connFor gets the conn for pkt if it exists, or returns nil -// if it does not. It returns an error when pkt does not contain a valid TCP -// header. -// TODO(gvisor.dev/issue/6168): Support UDP. -func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil, dirOriginal +func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple { + netHeader := pkt.Network() + transportHeader, ok := getTransportHeader(pkt) + if !ok { + return nil } - return ct.connForTID(tid) -} -func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { - bucket := ct.bucket(tid) - now := time.Now() + tid := tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: transportHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: transportHeader.DestinationPort(), + transProto: pkt.TransportProtocolNumber, + netProto: pkt.NetworkProtocolNumber, + } + + bktID := ct.bucket(tid) ct.mu.RLock() - defer ct.mu.RUnlock() - ct.buckets[bucket].mu.Lock() - defer ct.buckets[bucket].mu.Unlock() - - // Iterate over the tuples in a bucket, cleaning up any unused - // connections we find. - for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { - // Clean up any timed-out connections we happen to find. - if ct.reapTupleLocked(other, bucket, now) { - // The tuple expired. - continue - } - if tid == other.tupleID { - return other.conn, other.direction - } + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + now := time.Now() + if t := bkt.connForTID(tid, now); t != nil { + return t } - return nil, dirOriginal -} + bkt.mu.Lock() + defer bkt.mu.Unlock() -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil + // Make sure a connection wasn't added between when we last checked the + // bucket and acquired the bucket's write lock. + if t := bkt.connForTIDRLocked(tid, now); t != nil { + return t } - if hook != Prerouting && hook != Output { - return nil + + // This is the first packet we're seeing for the connection. Create an entry + // for this new connection. + conn := &conn{ + ct: ct, + original: tuple{tupleID: tid, direction: dirOriginal}, + reply: tuple{tupleID: tid.reply(), direction: dirReply}, + manip: manipNone, + lastUsed: now, } + conn.original.conn = conn + conn.reply.conn = conn + + // For now, we only map an entry for the packet's original tuple as NAT may be + // performed on this connection. Until the packet goes through all the hooks + // and its final address/port is known, we cannot know what the response + // packet's addresses/ports will look like. + // + // This is okay because the destination cannot send its response until it + // receives the packet; the packet will only be received once all the hooks + // have been performed. + // + // See (*conn).finalize. + bkt.tuples.PushFront(&conn.original) + return &conn.original +} - replyTID := tid.reply() - replyTID.srcAddr = address - replyTID.srcPort = port +func (ct *ConnTrack) connForTID(tid tupleID) *tuple { + bktID := ct.bucket(tid) - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil - } - conn = newConn(tid, replyTID, manipDestination, hook) - ct.insertConn(conn) - return conn + ct.mu.RLock() + bkt := &ct.buckets[bktID] + ct.mu.RUnlock() + + return bkt.connForTID(tid, time.Now()) } -func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { - tid, err := packetToTupleID(pkt) - if err != nil { - return nil - } - if hook != Input && hook != Postrouting { - return nil +func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple { + bkt.mu.RLock() + defer bkt.mu.RUnlock() + return bkt.connForTIDRLocked(tid, now) +} + +// +checklocks:bkt.mu +func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple { + for other := bkt.tuples.Front(); other != nil; other = other.Next() { + if tid == other.id() && !other.conn.timedOut(now) { + return other + } } + return nil +} - replyTID := tid.reply() - replyTID.dstAddr = address - replyTID.dstPort = port +func (ct *ConnTrack) finalize(cn *conn) { + tid := cn.reply.id() + id := ct.bucket(tid) - conn, _ := ct.connForTID(tid) - if conn != nil { - // The connection is already tracked. - // TODO(gvisor.dev/issue/5696): Support updating an existing connection. - return nil + ct.mu.RLock() + bkt := &ct.buckets[id] + ct.mu.RUnlock() + + bkt.mu.Lock() + defer bkt.mu.Unlock() + + if t := bkt.connForTIDRLocked(tid, time.Now()); t != nil { + // Another connection for the reply already exists. We can't do much about + // this so we leave the connection cn represents in a state where it can + // send packets but its responses will be mapped to some other connection. + // This may be okay if the connection only expects to send packets without + // any responses. + return } - conn = newConn(tid, replyTID, manipSource, hook) - ct.insertConn(conn) - return conn + + bkt.tuples.PushFront(&cn.reply) } -// insertConn inserts conn into the appropriate table bucket. -func (ct *ConnTrack) insertConn(conn *conn) { - // Lock the buckets in the correct order. - tupleBucket := ct.bucket(conn.original.tupleID) - replyBucket := ct.bucket(conn.reply.tupleID) - ct.mu.RLock() - defer ct.mu.RUnlock() - if tupleBucket < replyBucket { - ct.buckets[tupleBucket].mu.Lock() - ct.buckets[replyBucket].mu.Lock() - } else if tupleBucket > replyBucket { - ct.buckets[replyBucket].mu.Lock() - ct.buckets[tupleBucket].mu.Lock() - } else { - // Both tuples are in the same bucket. - ct.buckets[tupleBucket].mu.Lock() - } - - // Now that we hold the locks, ensure the tuple hasn't been inserted by - // another thread. - // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? - alreadyInserted := false - for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { - if other.tupleID == conn.original.tupleID { - alreadyInserted = true - break +func (cn *conn) finalize() { + { + cn.mu.RLock() + finalized := cn.finalized + cn.mu.RUnlock() + if finalized { + return } } - if !alreadyInserted { - // Add the tuple to the map. - ct.buckets[tupleBucket].tuples.PushFront(&conn.original) - ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + cn.mu.Lock() + finalized := cn.finalized + cn.finalized = true + cn.mu.Unlock() + if finalized { + return } - // Unlocking can happen in any order. - ct.buckets[tupleBucket].mu.Unlock() - if tupleBucket != replyBucket { - ct.buckets[replyBucket].mu.Unlock() // +checklocksforce - } + cn.ct.finalize(cn) } -// handlePacket will manipulate the port and address of the packet if the -// connection exists. Returns whether, after the packet traverses the tables, -// it should create a new entry in the table. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { - if pkt.NatDone { - return false +// performNAT setups up the connection for the specified NAT. +// +// Generally, only the first packet of a connection reaches this method; other +// other packets will be manipulated without needing to modify the connection. +func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) { + cn.performNATIfNoop(port, address, dnat) + cn.handlePacket(pkt, hook, r) +} + +func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) { + cn.mu.Lock() + defer cn.mu.Unlock() + + if cn.finalized { + return } - switch hook { - case Prerouting, Input, Output, Postrouting: - default: - return false + if cn.manip != manipNone { + return } - // TODO(gvisor.dev/issue/6168): Support UDP. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { - return false + cn.reply.mu.Lock() + defer cn.reply.mu.Unlock() + + if dnat { + cn.reply.tupleID.srcAddr = address + cn.reply.tupleID.srcPort = port + cn.manip = manipDestination + } else { + cn.reply.tupleID.dstAddr = address + cn.reply.tupleID.dstPort = port + cn.manip = manipSource } +} - conn, dir := ct.connFor(pkt) - // Connection not found for the packet. - if conn == nil { - // If this is the last hook in the data path for this packet (Input if - // incoming, Postrouting if outgoing), indicate that a connection should be - // inserted by the end of this hook. - return hook == Input || hook == Postrouting +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) { + if pkt.NatDone { + return } - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return false + transportHeader, ok := getTransportHeader(pkt) + if !ok { + return } + netHeader := pkt.Network() + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. @@ -410,49 +432,58 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { updateSRCFields := false + dir := pkt.tuple.direction + + cn.mu.Lock() + defer cn.mu.Unlock() + switch hook { case Prerouting, Output: - if conn.manip == manipDestination { - switch dir { - case dirOriginal: - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - case dirReply: - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - - updateSRCFields = true - } + if cn.manip == manipDestination && dir == dirOriginal { + id := cn.reply.id() + newPort = id.srcPort + newAddr = id.srcAddr + pkt.NatDone = true + } else if cn.manip == manipSource && dir == dirReply { + id := cn.original.id() + newPort = id.srcPort + newAddr = id.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource { - switch dir { - case dirOriginal: - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - - updateSRCFields = true - case dirReply: - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr - } + if cn.manip == manipSource && dir == dirOriginal { + id := cn.reply.id() + newPort = id.dstPort + newAddr = id.dstAddr + updateSRCFields = true + pkt.NatDone = true + } else if cn.manip == manipDestination && dir == dirReply { + id := cn.original.id() + newPort = id.dstPort + newAddr = id.dstAddr + updateSRCFields = true pkt.NatDone = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } + if !pkt.NatDone { - return false + return } fullChecksum := false updatePseudoHeader := false switch hook { - case Prerouting, Input: + case Prerouting: + // Packet came from outside the stack so it must have a checksum set + // already. + fullChecksum = true + updatePseudoHeader = true + case Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { fullChecksum = true @@ -464,7 +495,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { rewritePacket( netHeader, - tcpHeader, + transportHeader, updateSRCFields, fullChecksum, updatePseudoHeader, @@ -472,46 +503,10 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { newAddr, ) - // Update the state of tcb. - conn.mu.Lock() - defer conn.mu.Unlock() - // Mark the connection as having been used recently so it isn't reaped. - conn.lastUsed = time.Now() + cn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) - - return false -} - -// maybeInsertNoop tries to insert a no-op connection entry to keep connections -// from getting clobbered when replies arrive. It only inserts if there isn't -// already a connection for pkt. -// -// This should be called after traversing iptables rules only, to ensure that -// pkt.NatDone is set correctly. -func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { - // If there were a rule applying to this packet, it would be marked - // with NatDone. - if pkt.NatDone { - return - } - - // We only track TCP connections. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { - return - } - - // This is the first packet we're seeing for the TCP connection. Insert - // the noop entry (an identity mapping) so that the response doesn't - // get NATed, breaking the connection. - tid, err := packetToTupleID(pkt) - if err != nil { - return - } - conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) - ct.insertConn(conn) + cn.updateLocked(pkt, dir) } // bucket gets the conntrack bucket for a tupleID. @@ -563,14 +558,15 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim defer ct.mu.RUnlock() for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { idx = (i + start) % len(ct.buckets) - ct.buckets[idx].mu.Lock() - for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + bkt := &ct.buckets[idx] + bkt.mu.Lock() + for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() { checked++ - if ct.reapTupleLocked(tuple, idx, now) { + if ct.reapTupleLocked(tuple, idx, bkt, now) { expired++ } } - ct.buckets[idx].mu.Unlock() + bkt.mu.Unlock() } // We already checked buckets[idx]. idx++ @@ -595,44 +591,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // reapTupleLocked tries to remove tuple and its reply from the table. It // returns whether the tuple's connection has timed out. // -// Preconditions: -// * ct.mu is locked for reading. -// * bucket is locked. -func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { +// Precondition: ct.mu is read locked and bkt.mu is write locked. +// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. +// +checklocks:ct.mu +// +checklocks:bkt.mu +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now time.Time) bool { if !tuple.conn.timedOut(now) { return false } // To maintain lock order, we can only reap these tuples if the reply // appears later in the table. - replyBucket := ct.bucket(tuple.reply()) - if bucket > replyBucket { + replyBktID := ct.bucket(tuple.id().reply()) + if bktID > replyBktID { return true } // Don't re-lock if both tuples are in the same bucket. - differentBuckets := bucket != replyBucket - if differentBuckets { - ct.buckets[replyBucket].mu.Lock() + if bktID != replyBktID { + replyBkt := &ct.buckets[replyBktID] + replyBkt.mu.Lock() + removeConnFromBucket(replyBkt, tuple) + replyBkt.mu.Unlock() + } else { + removeConnFromBucket(bkt, tuple) } // We have the buckets locked and can remove both tuples. + bkt.tuples.Remove(tuple) + return true +} + +// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements. +// +checklocks:b.mu +func removeConnFromBucket(b *bucket, tuple *tuple) { if tuple.direction == dirOriginal { - ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + b.tuples.Remove(&tuple.conn.reply) } else { - ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) - } - ct.buckets[bucket].tuples.Remove(tuple) - - // Don't re-unlock if both tuples are in the same bucket. - if differentBuckets { - ct.buckets[replyBucket].mu.Unlock() // +checklocksforce + b.tuples.Remove(&tuple.conn.original) } - - return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -640,17 +640,22 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ srcPort: epID.LocalPort, dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, - transProto: header.TCPProtocolNumber, + transProto: transProto, netProto: netProto, } - conn, _ := ct.connForTID(tid) - if conn == nil { + t := ct.connForTID(tid) + if t == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip != manipDestination { + } + + t.conn.mu.RLock() + defer t.conn.mu.RUnlock() + if t.conn.manip != manipDestination { // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } - return conn.original.dstAddr, conn.original.dstPort, nil + id := t.conn.original.id() + return id.dstAddr, id.dstPort, nil } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 72f66441f..c2f1f4798 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen } -func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } @@ -342,6 +338,10 @@ func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, p return n, nil } +func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} @@ -380,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } // NIC 2 has the link address "b", and added the network address 2. @@ -393,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } nic, ok := s.nics[2] diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go index 3a20839da..99e5d2df7 100644 --- a/pkg/tcpip/stack/icmp_rate_limit.go +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -16,6 +16,7 @@ package stack import ( "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/tcpip" ) const ( @@ -31,11 +32,41 @@ const ( // ICMPRateLimiter is a global rate limiter that controls the generation of // ICMP messages generated by the stack. type ICMPRateLimiter struct { - *rate.Limiter + limiter *rate.Limiter + clock tcpip.Clock } // NewICMPRateLimiter returns a global rate limiter for controlling the rate -// at which ICMP messages are generated by the stack. -func NewICMPRateLimiter() *ICMPRateLimiter { - return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} +// at which ICMP messages are generated by the stack. The returned limiter +// does not apply limits to any ICMP types by default. +func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter { + return &ICMPRateLimiter{ + clock: clock, + limiter: rate.NewLimiter(icmpLimit, icmpBurst), + } +} + +// SetLimit sets a new Limit for the limiter. +func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) { + l.limiter.SetLimitAt(l.clock.Now(), limit) +} + +// Limit returns the maximum overall event rate. +func (l *ICMPRateLimiter) Limit() rate.Limit { + return l.limiter.Limit() +} + +// SetBurst sets a new burst size for the limiter. +func (l *ICMPRateLimiter) SetBurst(burst int) { + l.limiter.SetBurstAt(l.clock.Now(), burst) +} + +// Burst returns the maximum burst size. +func (l *ICMPRateLimiter) Burst() int { + return l.limiter.Burst() +} + +// Allow reports whether one ICMP message may be sent now. +func (l *ICMPRateLimiter) Allow() bool { + return l.limiter.AllowN(l.clock.Now(), 1) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f152c0d83..5808be685 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -264,26 +264,134 @@ const ( chainReturn ) -// Check runs pkt through the rules for hook. It returns true when the packet -// should continue traversing the network stack and false when it should be -// dropped. +// CheckPrerouting performs the prerouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. // -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { - if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool { + const hook = Prerouting + + if it.shouldSkip(pkt.NetworkProtocolNumber) { return true } + + if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil { + pkt.tuple = t + t.conn.handlePacket(pkt, hook, nil /* route */) + } + + return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) +} + +// CheckInput performs the input hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { + const hook = Input + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + if t := pkt.tuple; t != nil { + t.conn.handlePacket(pkt, hook, nil /* route */) + } + + ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) + if t := pkt.tuple; t != nil { + t.conn.finalize() + } + pkt.tuple = nil + return ret +} + +// CheckForward performs the forward hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool { + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) +} + +// CheckOutput performs the output hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool { + const hook = Output + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil { + pkt.tuple = t + t.conn.handlePacket(pkt, hook, r) + } + + return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// CheckPostrouting performs the postrouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool { + const hook = Postrouting + + if it.shouldSkip(pkt.NetworkProtocolNumber) { + return true + } + + if t := pkt.tuple; t != nil { + t.conn.handlePacket(pkt, hook, r) + } + + ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName) + if t := pkt.tuple; t != nil { + t.conn.finalize() + } + pkt.tuple = nil + return ret +} + +func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool { + switch netProto { + case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber: + default: + // IPTables only supports IPv4/IPv6. + return true + } + + it.mu.RLock() + defer it.mu.RUnlock() // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. + return !it.modified +} + +// check runs pkt through the rules for hook. It returns true when the packet +// should continue traversing the network stack and false when it should be +// dropped. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool { it.mu.RLock() defer it.mu.RUnlock() - if !it.modified { - return true - } - - // Packets are manipulated only if connection and matching - // NAT rule exists. - shouldTrack := it.connections.handlePacket(pkt, hook, r) // Go through each table containing the hook. priorities := it.priorities[hook] @@ -300,7 +408,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -311,7 +419,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v { case RuleAccept: continue case RuleDrop: @@ -327,21 +435,6 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr } } - // If this connection should be tracked, try to add an entry for it. If - // traversing the nat table didn't end in adding an entry, - // maybeInsertNoop will add a no-op entry for the connection. This is - // needeed when establishing connections so that the SYN/ACK reply to an - // outgoing SYN is delivered to the correct endpoint rather than being - // redirected by a prerouting rule. - // - // From the iptables documentation: "If there is no rule, a `null' - // binding is created: this usually does not map the packet, but exists - // to ensure we don't map another stream over an existing one." - if shouldTrack { - it.connections.maybeInsertNoop(pkt, hook) - } - - // Every table returned Accept. return true } @@ -375,19 +468,32 @@ func (it *IPTables) startReaper(interval time.Duration) { }() } -// CheckPackets runs pkts through the rules for hook and returns a map of packets that -// should not go forward. +// CheckOutputPackets performs the output hook on the packets. // -// Preconditions: -// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// * pkt.NetworkHeader is not nil. +// Returns a map of packets that must be dropped. // -// NOTE: unlike the Check API the returned map contains packets that should be -// dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckOutput(pkt, r, outNicName) + }) +} + +// CheckPostroutingPackets performs the postrouting hook on the packets. +// +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckPostrouting(pkt, r, addressEP, outNicName) + }) +} + +func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { + if ok := f(pkt); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -407,11 +513,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -428,7 +534,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -454,7 +560,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -477,16 +583,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) + return rule.Target.Action(pkt, hook, r, addressEP) } // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, &tcpip.ErrNotConnected{} } - return it.connections.originalDst(epID, netProto) + return it.connections.originalDst(epID, netProto, transProto) } diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go index 529e02a07..3d3c39c20 100644 --- a/pkg/tcpip/stack/iptables_state.go +++ b/pkg/tcpip/stack/iptables_state.go @@ -26,11 +26,15 @@ type unixTime struct { // saveLastUsed is invoked by stateify. func (cn *conn) saveLastUsed() unixTime { + cn.mu.Lock() + defer cn.mu.Unlock() return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} } // loadLastUsed is invoked by stateify. func (cn *conn) loadLastUsed(unix unixTime) { + cn.mu.Lock() + defer cn.mu.Unlock() cn.lastUsed = time.Unix(unix.second, unix.nano) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 96cc899bb..7e5a1672a 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,7 +79,7 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleReturn, 0 } @@ -97,7 +97,7 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -117,6 +117,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. + var address tcpip.Address switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { @@ -125,48 +126,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r address = header.IPv6Loopback } case Prerouting: - // No-op, as address is already set correctly. + // addressEP is expected to be set for the prerouting hook. + address = addressEP.MainAddress().Address default: panic("redirect target is supported only on output and prerouting hooks") } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - - if hook == Output { - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - udpHeader, - false, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - rt.Port, - address, - ) - } else { - udpHeader.SetDestinationPort(rt.Port) - } - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 - } - - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, r) - } - default: - return RuleDrop, 0 + if t := pkt.tuple; t != nil { + t.conn.performNAT(pkt, hook, r, rt.Port, address, true /* dnat */) + return RuleAccept, 0 } - return RuleAccept, 0 + return RuleDrop, 0 } // SNATTarget modifies the source port/IP in the outgoing packets. @@ -179,15 +150,7 @@ type SNATTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { - // Sanity check. - if st.NetworkProtocol != pkt.NetworkProtocolNumber { - panic(fmt.Sprintf( - "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", - st.NetworkProtocol, pkt.NetworkProtocolNumber)) - } - +func snatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -198,6 +161,33 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleDrop, 0 } + // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a + // different port. + if port == 0 { + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + port = header.UDP(pkt.TransportHeader().View()).SourcePort() + case header.TCPProtocolNumber: + port = header.TCP(pkt.TransportHeader().View()).SourcePort() + } + } + + if t := pkt.tuple; t != nil { + t.conn.performNAT(pkt, hook, r, port, address, false /* dnat */) + } + + return RuleAccept, 0 +} + +// Action implements Target.Action. +func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if st.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + st.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + switch hook { case Postrouting, Input: case Prerouting, Output, Forward: @@ -206,37 +196,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - header.UDP(pkt.TransportHeader().View()), - true, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - st.Port, - st.Addr, - ) - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 - } + return snatAction(pkt, hook, r, st.Port, st.Addr) +} - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) - } +// MasqueradeTarget modifies the source port/IP in the outgoing packets. +type MasqueradeTarget struct { + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if mt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + mt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Postrouting: + case Prerouting, Input, Forward, Output: + panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook)) default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + // addressEP is expected to be set for the postrouting hook. + ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */) + if ep == nil { + // No address exists that we can use as a source address. return RuleDrop, 0 } - return RuleAccept, 0 + address := ep.AddressWithPrefix().Address + ep.DecRef() + return snatAction(pkt, hook, r, 0 /* port */, address) } func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 66e5f22ac..b22024667 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -81,17 +81,6 @@ const ( // // +stateify savable type IPTables struct { - // mu protects v4Tables, v6Tables, and modified. - mu sync.RWMutex - // v4Tables and v6tables map tableIDs to tables. They hold builtin - // tables only, not user tables. mu must be locked for accessing. - v4Tables [NumTables]Table - v6Tables [NumTables]Table - // modified is whether tables have been modified at least once. It is - // used to elide the iptables performance overhead for workloads that - // don't utilize iptables. - modified bool - // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. @@ -101,6 +90,21 @@ type IPTables struct { // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} + + mu sync.RWMutex + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. + // + // +checklocks:mu + v4Tables [NumTables]Table + // +checklocks:mu + v6Tables [NumTables]Table + // modified is whether tables have been modified at least once. It is + // used to elide the iptables performance overhead for workloads that + // don't utilize iptables. + // + // +checklocks:mu + modified bool } // VisitTargets traverses all the targets of all tables and replaces each with @@ -352,5 +356,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 4d5431da1..40b33b6b5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Should get the address immediately since we should not have performed @@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + }, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Make sure the address does not resolve before the resolution time has @@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet @@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) { // Attempting to add the address again should not fail if the address's // state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } }) } @@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) { // Add addresses for each NIC. addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix1, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err) } addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix2, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err) } expectDADEvent(nicID2, addr2) addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix3, + } + if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err) } expectDADEvent(nicID3, addr3) @@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { continue } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: test.addrs[j].Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{} @@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr2, } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err) } // addr2 should be more preferred now since it is at the front of the primary // list. @@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr} + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err) } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) @@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) { } if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index b854d868c..29d580e76 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -72,9 +72,15 @@ type nic struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained packetEndpointList are - // not. - packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList + } + + packetEPs struct { + mu sync.RWMutex + + // eps is protected by the mutex, but the values contained in it are not. + // + // +checklocks:mu + eps map[tcpip.NetworkProtocolNumber]*packetEndpointList } } @@ -91,6 +97,8 @@ type packetEndpointList struct { mu sync.RWMutex // eps is protected by mu, but the contained PacketEndpoint values are not. + // + // +checklocks:mu eps []PacketEndpoint } @@ -111,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) { } } +func (p *packetEndpointList) len() int { + p.mu.RLock() + defer p.mu.RUnlock() + return len(p.eps) +} + // forEach calls fn with each endpoints in p while holding the read lock on p. func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { p.mu.RLock() @@ -143,18 +157,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), } nic.linkResQueue.init(nic) - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) + + nic.packetEPs.mu.Lock() + defer nic.packetEPs.mu.Unlock() + + nic.packetEPs.eps = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 - // Register supported packet and network endpoint protocols. - for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = new(packetEndpointList) - } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = new(packetEndpointList) - netEP := netProto.NewEndpoint(nic, nic) nic.networkEndpoints[netNum] = netEP @@ -365,6 +377,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt pkt.EgressRoute = r pkt.NetworkProtocolNumber = protocol + n.deliverOutboundPacket(r.RemoteLinkAddress, pkt) + if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil { return err } @@ -383,6 +397,7 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r pkt.NetworkProtocolNumber = protocol + n.deliverOutboundPacket(r.RemoteLinkAddress, pkt) } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) @@ -501,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return &tcpip.ErrUnknownProtocol{} @@ -512,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return &tcpip.ErrNotSupported{} } - addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -699,12 +714,9 @@ func (n *nic) isInGroup(addr tcpip.Address) bool { // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - n.mu.RLock() enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. if !enabled { - n.mu.RUnlock() - n.stats.disabledRx.packets.Increment() n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return @@ -715,7 +727,6 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { - n.mu.RUnlock() n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -727,44 +738,87 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 - // Are any packet type sockets listening for this network protocol? - protoEPs := n.mu.packetEPs[protocol] - // Other packet type sockets that are listening for all protocols. - anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] - n.mu.RUnlock() - // Deliver to interested packet endpoints without holding NIC lock. + var packetEPPkt *PacketBuffer deliverPacketEPs := func(ep PacketEndpoint) { - p := pkt.Clone() - p.PktType = tcpip.PacketHost - ep.HandlePacket(n.id, local, protocol, p) + if packetEPPkt == nil { + // Packet endpoints hold the full packet. + // + // We perform a deep copy because higher-level endpoints may point to + // the middle of a view that is held by a packet endpoint. Save/Restore + // does not support overlapping slices and will panic in this case. + // + // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports + // overlapping slices (e.g. by passing a shallow copy of pkt to the packet + // endpoint). + packetEPPkt = NewPacketBuffer(PacketBufferOptions{ + Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(), + }) + // If a link header was populated in the original packet buffer, then + // populate it in the packet buffer we provide to packet endpoints as + // packet endpoints inspect link headers. + packetEPPkt.LinkHeader().Consume(pkt.LinkHeader().View().Size()) + packetEPPkt.PktType = tcpip.PacketHost + } + + ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone()) } - if protoEPs != nil { + + n.packetEPs.mu.Lock() + // Are any packet type sockets listening for this network protocol? + protoEPs, protoEPsOK := n.packetEPs.eps[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs, anyEPsOK := n.packetEPs.eps[header.EthernetProtocolAll] + n.packetEPs.mu.Unlock() + + if protoEPsOK { protoEPs.forEach(deliverPacketEPs) } - if anyEPs != nil { + if anyEPsOK { anyEPs.forEach(deliverPacketEPs) } networkEndpoint.HandlePacket(pkt) } -// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. -func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - n.mu.RLock() +// deliverOutboundPacket delivers outgoing packets to interested endpoints. +func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) { + n.packetEPs.mu.RLock() + defer n.packetEPs.mu.RUnlock() // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - eps := n.mu.packetEPs[header.EthernetProtocolAll] - n.mu.RUnlock() + eps, ok := n.packetEPs.eps[header.EthernetProtocolAll] + if !ok { + return + } + + local := n.LinkAddress() + var packetEPPkt *PacketBuffer eps.forEach(func(ep PacketEndpoint) { - p := pkt.Clone() - p.PktType = tcpip.PacketOutgoing - // Add the link layer header as outgoing packets are intercepted - // before the link layer header is created. - n.LinkEndpoint.AddHeader(local, remote, protocol, p) - ep.HandlePacket(n.id, local, protocol, p) + if packetEPPkt == nil { + // Packet endpoints hold the full packet. + // + // We perform a deep copy because higher-level endpoints may point to + // the middle of a view that is held by a packet endpoint. Save/Restore + // does not support overlapping slices and will panic in this case. + // + // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports + // overlapping slices (e.g. by passing a shallow copy of pkt to the packet + // endpoint). + packetEPPkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: pkt.AvailableHeaderBytes(), + Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), + }) + // Add the link layer header as outgoing packets are intercepted before + // the link layer header is created and packet endpoints are interested + // in the link header. + n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt) + packetEPPkt.PktType = tcpip.PacketOutgoing + } + + ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone()) }) } @@ -917,12 +971,13 @@ func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigura } func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() + n.packetEPs.mu.Lock() + defer n.packetEPs.mu.Unlock() - eps, ok := n.mu.packetEPs[netProto] + eps, ok := n.packetEPs.eps[netProto] if !ok { - return &tcpip.ErrNotSupported{} + eps = new(packetEndpointList) + n.packetEPs.eps[netProto] = eps } eps.add(ep) @@ -930,14 +985,17 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa } func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { - n.mu.Lock() - defer n.mu.Unlock() + n.packetEPs.mu.Lock() + defer n.packetEPs.mu.Unlock() - eps, ok := n.mu.packetEPs[netProto] + eps, ok := n.packetEPs.eps[netProto] if !ok { return } eps.remove(ep) + if eps.len() == 0 { + delete(n.packetEPs.eps, netProto) + } } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 5cb342f78..c8ad93f29 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. -func (*testIPv6Protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements NetworkProtocol.ParseAddresses. func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9192d8433..888a8bd9d 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -143,6 +143,8 @@ type PacketBuffer struct { // NetworkPacketInfo holds an incoming packet's network-layer information. NetworkPacketInfo NetworkPacketInfo + + tuple *tuple } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -282,14 +284,12 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { return v } -// Clone makes a shallow copy of pk. -// -// Clone should be called in such cases so that no modifications is done to -// underlying packet payload. +// Clone makes a semi-deep copy of pk. The underlying packet payload is +// shared. Hence, no modifications is done to underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - buf: pk.buf, + buf: pk.buf.Clone(), reserved: pk.reserved, pushed: pk.pushed, consumed: pk.consumed, @@ -304,6 +304,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NICID: pk.NICID, RXTransportChecksumValidated: pk.RXTransportChecksumValidated, NetworkPacketInfo: pk.NetworkPacketInfo, + tuple: pk.tuple, } } @@ -321,25 +322,51 @@ func (pk *PacketBuffer) Network() header.Network { } } -// CloneToInbound makes a shallow copy of the packet buffer to be used as an -// inbound packet. +// CloneToInbound makes a semi-deep copy of the packet buffer (similar to +// Clone) to be used as an inbound packet. // // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { newPk := &PacketBuffer{ - buf: pk.buf, + buf: pk.buf.Clone(), // Treat unfilled header portion as reserved. reserved: pk.AvailableHeaderBytes(), + tuple: pk.tuple, + } + return newPk +} + +// DeepCopyForForwarding creates a deep copy of the packet buffer for +// forwarding. +// +// The returned packet buffer will have the network and transport headers +// set if the original packet buffer did. +func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer { + newPk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: reservedHeaderBytes, + Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(), + IsForwardedPacket: true, + }) + + { + consumeBytes := pk.NetworkHeader().View().Size() + if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed { + panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes)) + } + newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber } - // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to - // maintain this flag in the packet. Currently conntrack needs this flag to - // tell if a noop connection should be inserted at Input hook. Once conntrack - // redefines the manipulation field as mutable, we won't need the special noop - // connection. - if pk.NatDone { - newPk.NatDone = true + + { + consumeBytes := pk.TransportHeader().View().Size() + if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed { + panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes)) + } + newPk.TransportProtocolNumber = pk.TransportProtocolNumber } + + newPk.tuple = pk.tuple + return newPk } @@ -391,13 +418,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// DeleteFront removes count from the beginning of d. It panics if count > -// d.Size(). All backing storage references after the front of the d are -// invalidated. -func (d PacketData) DeleteFront(count int) { - if !d.pk.buf.Remove(d.pk.dataOffset(), count) { - panic("count > d.Size()") +// Consume is the same as PullUp except that is additionally consumes the +// returned bytes. Subsequent PullUp or Consume will not return these bytes. +func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) { + v, ok := d.PullUp(size) + if ok { + d.pk.consumed += size } + return v, ok } // CapLength reduces d to at most length bytes. diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index a8da34992..c376ed1a1 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -435,11 +435,17 @@ func TestPacketBufferData(t *testing.T) { } }) - // DeleteFront + // Consume. for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().DeleteFront(n) + v, ok := pkt.Data().Consume(n) + if !ok { + t.Fatalf("Consume failed") + } + if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) { + t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want) + } checkData(t, pkt, []byte(tc.data)[n:]) }) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index dfe2c886f..31b3a554d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int const ( // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. + // endpoint for new connections with no local address. CanBePrimaryEndpoint PrimaryEndpointBehavior = iota // FirstPrimaryEndpoint indicates the endpoint should be the first @@ -332,6 +331,19 @@ const ( NeverPrimaryEndpoint ) +func (peb PrimaryEndpointBehavior) String() string { + switch peb { + case CanBePrimaryEndpoint: + return "CanBePrimaryEndpoint" + case FirstPrimaryEndpoint: + return "FirstPrimaryEndpoint" + case NeverPrimaryEndpoint: + return "NeverPrimaryEndpoint" + default: + panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb)) + } +} + // AddressConfigType is the method used to add an address. type AddressConfigType int @@ -351,6 +363,14 @@ const ( AddressConfigSlaacTemp ) +// AddressProperties contains additional properties that can be configured when +// adding an address. +type AddressProperties struct { + PEB PrimaryEndpointBehavior + ConfigType AddressConfigType + Deprecated bool +} + // AssignableAddressEndpoint is a reference counted address endpoint that may be // assigned to a NetworkEndpoint. type AssignableAddressEndpoint interface { @@ -457,7 +477,7 @@ type AddressableEndpoint interface { // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. @@ -685,9 +705,6 @@ type NetworkProtocol interface { // than this targeted at this protocol. MinimumPacketSize() int - // DefaultPrefixLen returns the protocol's default prefix length. - DefaultPrefixLen() int - // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) @@ -733,16 +750,6 @@ type NetworkDispatcher interface { // // DeliverNetworkPacket takes ownership of pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) - - // DeliverOutboundPacket is called by link layer when a packet is being - // sent out. - // - // pkt.LinkHeader may or may not be set before calling - // DeliverOutboundPacket. Some packets do not have link headers (e.g. - // packets sent via loopback), and won't have the field set. - // - // DeliverOutboundPacket takes ownership of pkt. - DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -846,6 +853,14 @@ type LinkEndpoint interface { // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + + // WriteRawPacket writes a packet directly to the link. + // + // If the link-layer has its own header, the payload must already include the + // header. + // + // WriteRawPacket takes ownership of the packet. + WriteRawPacket(*PacketBuffer) tcpip.Error } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c73890c4c..428350f31 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -72,7 +72,8 @@ type Stack struct { // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. - rawFactory RawFactory + rawFactory RawFactory + packetEndpointWriteSupported bool demux *transportDemuxer @@ -119,8 +120,7 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // seed is a one-time random value initialized at stack startup - // and is used to seed the TCP port picking on active connections + // seed is a one-time random value initialized at stack startup. // // TODO(gvisor.dev/issue/940): S/R this field. seed uint32 @@ -161,6 +161,10 @@ type Stack struct { // This is required to prevent potential ACK loops. // Setting this to 0 will disable all rate limiting. tcpInvalidRateLimit time.Duration + + // tsOffsetSecret is the secret key for generating timestamp offsets + // initialized at stack startup. + tsOffsetSecret uint32 } // UniqueID is an abstract generator of unique identifiers. @@ -215,6 +219,10 @@ type Options struct { // this is non-nil. RawFactory RawFactory + // AllowPacketEndpointWrite determines if packet endpoints support write + // operations. + AllowPacketEndpointWrite bool + // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data // returned by the stack secure RNG. @@ -356,23 +364,24 @@ func New(opts Options) *Stack { opts.NUDConfigs.resetInvalidFields() s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*nic), - defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), - seed: seed, - nudConfigs: opts.NUDConfigs, - uniqueIDGenerator: opts.UniqueID, - nudDisp: opts.NUDDisp, - randomGenerator: randomGenerator, - secureRNG: opts.SecureRNG, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + nics: make(map[tcpip.NICID]*nic), + packetEndpointWriteSupported: opts.AllowPacketEndpointWrite, + defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: opts.IPTables, + icmpRateLimiter: NewICMPRateLimiter(clock), + seed: seed, + nudConfigs: opts.NUDConfigs, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + randomGenerator: randomGenerator, + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -384,6 +393,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, tcpInvalidRateLimit: defaultTCPInvalidRateLimit, + tsOffsetSecret: randomGenerator.Uint32(), } // Add specified network protocols. @@ -906,46 +916,9 @@ type NICStateFlags struct { Loopback bool } -// AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { - return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) -} - -// AddAddressWithPrefix is the same as AddAddress, but allows you to specify -// the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { - ap := tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: addr, - } - return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) -} - -// AddProtocolAddress adds a new network-layer protocol address to the -// specified NIC. -func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { - return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) -} - -// AddAddressWithOptions is the same as AddAddress, but allows you to specify -// whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return &tcpip.ErrUnknownProtocol{} - } - return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb) -} - -// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows -// you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +// AddProtocolAddress adds an address to the specified NIC, possibly with extra +// properties. +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -954,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return &tcpip.ErrUnknownNICID{} } - return nic.addAddress(protocolAddress, peb) + return nic.addAddress(protocolAddress, properties) } // RemoveAddress removes an existing network-layer address from the specified @@ -1649,9 +1622,27 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, ReserveHeaderBytes: int(nic.MaxHeaderLength()), Data: payload, }) + pkt.NetworkProtocolNumber = netProto return nic.WritePacketToRemote(remote, netProto, pkt) } +// WriteRawPacket writes data directly to the specified NIC without adding any +// headers. +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + if !ok { + return &tcpip.ErrUnknownNICID{} + } + + pkt := NewPacketBuffer(PacketBufferOptions{ + Data: payload, + }) + pkt.NetworkProtocolNumber = proto + return nic.WriteRawPacket(pkt) +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1819,8 +1810,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol return nic.setNUDConfigs(proto, c) } -// Seed returns a 32 bit value that can be used as a seed value for port -// picking, ISN generation etc. +// Seed returns a 32 bit value that can be used as a seed value. // // NOTE: The seed is generated once during stack initialization only. func (s *Stack) Seed() uint32 { @@ -1944,3 +1934,9 @@ func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProto return false } + +// PacketEndpointWriteSupported returns true iff packet endpoints support write +// operations. +func (s *Stack) PacketEndpointWriteSupported() bool { + return s.packetEndpointWriteSupported +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3089c0ef4..c23e91702 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().Consume(fakeNetHeaderLen) if !ok { return } - // DeleteFront invalidates slices. Make a copy before trimming. - nb := append([]byte(nil), hdr...) - pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), // Nothing checks the error. nil, /* transport error */ pkt, @@ -234,10 +231,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } -func (*fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { return f.packetCount[int(intfAddr)%len(f.packetCount)] } @@ -349,12 +342,26 @@ func TestNetworkReceive(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -517,8 +524,15 @@ func TestNetworkSend(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Make sure that the link-layer endpoint received the outbound packet. @@ -538,12 +552,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -551,12 +579,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -812,8 +854,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } ep2 := channel.New(1, defaultMTU, "") @@ -821,8 +870,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr2, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } // Set a route table that sends all packets with odd destination @@ -978,12 +1034,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -991,12 +1061,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -1058,8 +1142,15 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1108,8 +1199,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1242,8 +1340,15 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1270,8 +1375,8 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1310,8 +1415,8 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1453,8 +1558,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) @@ -1510,8 +1622,15 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { @@ -1633,8 +1752,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) + if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { @@ -1678,13 +1797,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { t.Fatalf("CreateNIC failed: %s", err) } nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) + if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) + if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err) } // Set the initial route table. @@ -1726,7 +1845,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1}, }, rt..., ) @@ -1808,8 +1927,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) } - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: anyAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { @@ -1886,22 +2012,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Add an address and in case of a primary one include a // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) + properties := stack.AddressProperties{PEB: behavior} if behavior == stack.CanBePrimaryEndpoint { protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, + Protocol: fakeNetNumber, + AddressWithPrefix: address.WithPrefix(), } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } // Remember the address/prefix. primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } } } @@ -1996,8 +2127,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { PrefixLen: tc.prefixLen, }, } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) + if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -2047,33 +2178,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } } -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ @@ -2084,96 +2188,43 @@ func TestAddProtocolAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - addrLenRange := []int{4, 16} behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) + configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp} + deprecatedRange := []bool{false, true} + wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange)) var addrGen addressGenerator for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) + for _, configType := range configTypeRange { + for _, deprecated := range deprecatedRange { + address := addrGen.next(addrLen) + properties := stack.AddressProperties{ + PEB: behavior, + ConfigType: configType, + Deprecated: deprecated, + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err) + } + wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, + }) } - expectedAddresses = append(expectedAddresses, protocolAddress) } } } gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) + verifyAddresses(t, wantAddresses, gotAddresses) } func TestCreateNICWithOptions(t *testing.T) { @@ -2290,8 +2341,15 @@ func TestNICStats(t *testing.T) { if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed: ", err) } - if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: nic.addr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err) } { @@ -2735,8 +2793,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // be returned by a call to GetMainNICAddress; // else, it should. const address1 = tcpip.Address("\x01") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + properties := stack.AddressProperties{PEB: pi} + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err) } addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) if err != nil { @@ -2785,16 +2851,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. const address3 = tcpip.Address("\x03") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) - + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address3, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err) } // Add back the address we removed earlier and // make sure the new peb was respected. // (The address should just be promoted now). - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: ps} + if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { @@ -3096,8 +3177,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: a.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -3203,8 +3288,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // The NIC should have joined addr1's solicited node multicast address. @@ -3359,8 +3448,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { PrefixLen: 128, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } // Address should be in the list of all addresses. @@ -3687,8 +3776,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) @@ -3750,8 +3839,8 @@ func TestResolveWith(t *testing.T) { PrefixLen: 24, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) @@ -3792,8 +3881,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -3881,8 +3977,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -3990,44 +4086,44 @@ func TestFindRouteWithForwarding(t *testing.T) { ) type netCfg struct { - proto tcpip.NetworkProtocolNumber - factory stack.NetworkProtocolFactory - nic1Addr tcpip.Address - nic2Addr tcpip.Address - remoteAddr tcpip.Address + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1AddrWithPrefix tcpip.AddressWithPrefix + nic2AddrWithPrefix tcpip.AddressWithPrefix + remoteAddr tcpip.Address } fakeNetCfg := netCfg{ - proto: fakeNetNumber, - factory: fakeNetFactory, - nic1Addr: nic1Addr, - nic2Addr: nic2Addr, - remoteAddr: remoteAddr, + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen}, + nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen}, + remoteAddr: remoteAddr, } globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: llAddr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: globalIPv6Addr1, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: llAddr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: globalIPv6Addr1, } ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: llAddr1, - remoteAddr: llAddr2, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: llAddr1.WithPrefix(), + remoteAddr: llAddr2, } ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", } tests := []struct { @@ -4036,8 +4132,8 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg netCfg forwardingEnabled bool - addrNIC tcpip.NICID - localAddr tcpip.Address + addrNIC tcpip.NICID + localAddrWithPrefix tcpip.AddressWithPrefix findRouteErr tcpip.Error dependentOnForwarding bool @@ -4047,7 +4143,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4056,7 +4152,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4065,7 +4161,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4074,7 +4170,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4083,7 +4179,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4092,7 +4188,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4101,7 +4197,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4110,7 +4206,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4118,7 +4214,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4126,7 +4222,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4134,7 +4230,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4142,7 +4238,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: true, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4166,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on different NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4174,7 +4270,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4182,7 +4278,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4190,7 +4286,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4198,7 +4294,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4206,7 +4302,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4214,7 +4310,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4222,7 +4318,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4230,7 +4326,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4238,7 +4334,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4246,7 +4342,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4268,12 +4364,20 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) } - if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic1AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } - if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic2AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { @@ -4282,20 +4386,20 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) - r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { - t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { return } - if r.LocalAddress() != test.localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr) + if r.LocalAddress() != test.localAddrWithPrefix.Address { + t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address) } if r.RemoteAddress() != test.netCfg.remoteAddr { t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr) @@ -4318,8 +4422,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if !ok { t.Fatal("packet not sent through ep2") } - if pkt.Route.LocalAddress != test.localAddr { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address) } if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index 90a8ba6cf..a941091b0 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/internal/tcp" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -288,6 +289,12 @@ type TCPSenderState struct { // RACKState holds the state related to RACK loss detection algorithm. RACKState TCPRACKState + + // RetransmitTS records the timestamp used to detect spurious recovery. + RetransmitTS uint32 + + // SpuriousRecovery indicates if the sender entered recovery spuriously. + SpuriousRecovery bool } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. @@ -386,6 +393,12 @@ type TCPSndBufState struct { // SndMTU is the smallest MTU seen in the control packets received. SndMTU int + + // AutoTuneSndBufDisabled indicates that the auto tuning of send buffer + // is disabled. + // + // Must be accessed using atomic operations. + AutoTuneSndBufDisabled uint32 } // TCPEndpointStateInner contains the members of TCPEndpointState used directly @@ -396,7 +409,7 @@ type TCPSndBufState struct { type TCPEndpointStateInner struct { // TSOffset is a randomized offset added to the value of the TSVal // field in the timestamp option. - TSOffset uint32 + TSOffset tcp.TSOffset // SACKPermitted is set to true if the peer sends the TCPSACKPermitted // option in the SYN/SYN-ACK. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index dda57e225..542d9257c 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -32,11 +32,13 @@ type protocolIDs struct { // transportEndpoints manages all endpoints of a given protocol. It has its own // mutex so as to reduce interference between protocols. type transportEndpoints struct { - // mu protects all fields of the transportEndpoints. - mu sync.RWMutex + mu sync.RWMutex + // +checklocks:mu endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. + // + // +checklocks:mu rawEndpoints []RawTransportEndpoint } @@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { // descending order of match quality. If a call to yield returns false, // iterEndpointsLocked stops iteration and returns immediately. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in // descending order of match quality. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { var matchedEPs []*endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) [] // findEndpointLocked returns the endpoint that most closely matches the given id. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { var matchedEP *endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo } type endpointsByNIC struct { - mu sync.RWMutex - endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 + + mu sync.RWMutex + // +checklocks:mu + endpoints map[tcpip.NICID]*multiPortEndpoint } func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { @@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet return true } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNIC.seed) + transEP := mpep.selectEndpoint(id, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() @@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) + mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` demux *transportDemuxer netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber + flags ports.FlagCounter + + mu sync.RWMutex `state:"nosave"` // endpoints stores the transport endpoints in the order in which they // were bound. This is required for UDP SO_REUSEADDR. + // + // +checklocks:mu endpoints []TransportEndpoint - flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpoints) == 1 { - return mpep.endpoints[0] +func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint { + ep.mu.RLock() + defer ep.mu.RUnlock() + + if len(ep.endpoints) == 1 { + return ep.endpoints[0] } - if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent { - return mpep.endpoints[len(mpep.endpoints)-1] + if ep.flags.SharedFlags().ToFlags().Effective().MostRecent { + return ep.endpoints[len(ep.endpoints)-1] } payload := []byte{ @@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) - return mpep.endpoints[idx] + idx := reciprocalScale(hash, uint32(len(ep.endpoints))) + return ep.endpoints[idx] } func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { @@ -479,7 +489,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: d.stack.Seed(), + seed: d.stack.seed, } } if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil { @@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN } } - ep := selectEndpoint(id, mpep, epsByNIC.seed) + ep := mpep.selectEndpoint(id, epsByNIC.seed) epsByNIC.mu.RUnlock() return ep } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 45b09110d..cd3a8c25a 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -35,7 +35,7 @@ import ( const ( testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") testSrcAddrV4 = "\x0a\x00\x00\x01" testDstAddrV4 = "\x0a\x00\x00\x02" @@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { - t.Fatalf("AddAddress IPv4 failed: %s", err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { - t.Fatalf("AddAddress IPv6 failed: %s", err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: testDstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err) } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 839178809..655931715 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -357,8 +357,15 @@ func TestTransportReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -428,8 +435,15 @@ func TestTransportControlReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -497,8 +511,15 @@ func TestTransportSend(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 55683b4fb..460a6afaf 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -19,7 +19,7 @@ // The starting point is the creation and configuration of a stack. A stack can // be created by calling the New() function of the tcpip/stack/stack package; // configuring a stack involves creating NICs (via calls to Stack.CreateNIC()), -// adding network addresses (via calls to Stack.AddAddress()), and +// adding network addresses (via calls to Stack.AddProtocolAddress()), and // setting a route table (via a call to Stack.SetRouteTable()). // // Once a stack is configured, endpoints can be created by calling @@ -423,9 +423,9 @@ type ControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packet used to create - // the read data was received. - Timestamp int64 + // Timestamp is the time that the last packet used to create the read data + // was received. + Timestamp time.Time `state:".(int64)"` // HasInq indicates whether Inq is valid/set. HasInq bool @@ -451,6 +451,12 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set. + HasIPv6PacketInfo bool + + // IPv6PacketInfo holds interface and address data on an incoming packet. + IPv6PacketInfo IPv6PacketInfo + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is // set. HasOriginalDstAddress bool @@ -465,10 +471,10 @@ type ControlMessages struct { // PacketOwner is used to get UID and GID of the packet. type PacketOwner interface { - // UID returns KUID of the packet. + // KUID returns KUID of the packet. KUID() uint32 - // GID returns KGID of the packet. + // KGID returns KGID of the packet. KGID() uint32 } @@ -1164,6 +1170,14 @@ type IPPacketInfo struct { DestinationAddr Address } +// IPv6PacketInfo is the message structure for IPV6_PKTINFO. +// +// +stateify savable +type IPv6PacketInfo struct { + Addr Address + NIC NICID +} + // SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to // get/set the default, min and max send buffer sizes. type SendBufferSizeOption struct { @@ -1231,11 +1245,11 @@ type Route struct { // String implements the fmt.Stringer interface. func (r Route) String() string { var out strings.Builder - fmt.Fprintf(&out, "%s", r.Destination) + _, _ = fmt.Fprintf(&out, "%s", r.Destination) if len(r.Gateway) > 0 { - fmt.Fprintf(&out, " via %s", r.Gateway) + _, _ = fmt.Fprintf(&out, " via %s", r.Gateway) } - fmt.Fprintf(&out, " nic %d", r.NIC) + _, _ = fmt.Fprintf(&out, " nic %d", r.NIC) return out.String() } @@ -1255,6 +1269,8 @@ type TransportProtocolNumber uint32 type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. +// +// +stateify savable type StatCounter struct { count atomicbitops.AlignedAtomicUint64 } @@ -1270,7 +1286,7 @@ func (s *StatCounter) Decrement() { } // Value returns the current value of the counter. -func (s *StatCounter) Value(name ...string) uint64 { +func (s *StatCounter) Value(...string) uint64 { return s.count.Load() } @@ -1849,6 +1865,10 @@ type TCPStats struct { // SegmentsAckedWithDSACK is the number of segments acknowledged with // DSACK. SegmentsAckedWithDSACK *StatCounter + + // SpuriousRecovery is the number of times the connection entered loss + // recovery spuriously. + SpuriousRecovery *StatCounter } // UDPStats collects UDP-specific stats. @@ -1981,6 +2001,8 @@ type Stats struct { } // ReceiveErrors collects packet receive errors within transport endpoint. +// +// +stateify savable type ReceiveErrors struct { // ReceiveBufferOverflow is the number of received packets dropped // due to the receive buffer being full. @@ -1998,8 +2020,10 @@ type ReceiveErrors struct { ChecksumErrors StatCounter } -// SendErrors collects packet send errors within the transport layer for -// an endpoint. +// SendErrors collects packet send errors within the transport layer for an +// endpoint. +// +// +stateify savable type SendErrors struct { // SendToNetworkFailed is the number of packets failed to be written to // the network endpoint. @@ -2010,6 +2034,8 @@ type SendErrors struct { } // ReadErrors collects segment read errors from an endpoint read call. +// +// +stateify savable type ReadErrors struct { // ReadClosed is the number of received packet drops because the endpoint // was shutdown for read. @@ -2025,6 +2051,8 @@ type ReadErrors struct { } // WriteErrors collects packet write errors from an endpoint write call. +// +// +stateify savable type WriteErrors struct { // WriteClosed is the number of packet drops because the endpoint // was shutdown for write. @@ -2040,6 +2068,8 @@ type WriteErrors struct { } // TransportEndpointStats collects statistics about the endpoint. +// +// +stateify savable type TransportEndpointStats struct { // PacketsReceived is the number of successful packet receives. PacketsReceived StatCounter diff --git a/pkg/tcpip/tcpip_state.go b/pkg/tcpip/tcpip_state.go new file mode 100644 index 000000000..1953e24a1 --- /dev/null +++ b/pkg/tcpip/tcpip_state.go @@ -0,0 +1,27 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "time" +) + +func (c *ControlMessages) saveTimestamp() int64 { + return c.Timestamp.UnixNano() +} + +func (c *ControlMessages) loadTimestamp(nsec int64) { + c.Timestamp = time.Unix(0, nsec) +} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 181ef799e..7c998eaae 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -34,12 +34,16 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 92fa6257d..6e1d4720d 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr, } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr, + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) { addr: utils.RouterNIC2IPv6Addr, }, } { - if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) + if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err) } } diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index f9ab7d0af..f01e2b128 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -15,19 +15,24 @@ package iptables_test import ( + "bytes" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) type inputIfNameMatcher struct { @@ -49,10 +54,10 @@ const ( nicName = "nic1" anotherNicName = "nic2" linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - srcAddrV4 = "\x0a\x00\x00\x01" - dstAddrV4 = "\x0a\x00\x00\x02" - srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01") + dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02") + srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") payloadSize = 20 ) @@ -66,8 +71,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: dstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, e } @@ -82,8 +91,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: dstAddrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } return s, e } @@ -601,11 +614,19 @@ func TestIPTableWritePackets(t *testing.T) { if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: srcAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) + } + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: srcAddrV4.WithPrefix(), } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) + if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) } s.SetRouteTable([]tcpip.Route{ @@ -856,11 +877,19 @@ func TestForwardingHook(t *testing.T) { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(), } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -1037,22 +1066,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) { if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err) } - if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err) } e2 := channel.New(1, header.IPv6MinimumMTU, "") if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) } - if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err) } - if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err) } if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -1132,3 +1161,312 @@ func TestInputHookWithLocalForwarding(t *testing.T) { }) } } + +func TestSNAT(t *testing.T) { + const listenPort = 8080 + + type endpointAndAddresses struct { + serverEP tcpip.Endpoint + serverAddr tcpip.Address + serverReadableCH chan struct{} + + clientEP tcpip.Endpoint + clientAddr tcpip.Address + clientReadableCH chan struct{} + + nattedClientAddr tcpip.Address + } + + newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + t.Cleanup(func() { + wq.EventUnregister(&we) + }) + + ep, err := s.NewEndpoint(transProto, netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) + } + t.Cleanup(ep.Close) + + return ep, ch + } + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses + }{ + { + name: "IPv4 host1 server with host2 client", + netProto: ipv4.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + + nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, + } + }, + }, + { + name: "IPv6 host1 server with host2 client", + netProto: ipv6.ProtocolNumber, + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + + nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, + } + }, + }, + } + + subTests := []struct { + name string + proto tcpip.TransportProtocolNumber + expectedConnectErr tcpip.Error + setupServer func(t *testing.T, ep tcpip.Endpoint) + setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + needRemoteAddr bool + }{ + { + name: "UDP", + proto: udp.ProtocolNumber, + expectedConnectErr: nil, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + if err := ep.Connect(clientAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) + } + return nil, nil + }, + needRemoteAddr: true, + }, + { + name: "TCP", + proto: tcp.ProtocolNumber, + expectedConnectErr: &tcpip.ErrConnectStarted{}, + setupServer: func(t *testing.T, ep tcpip.Endpoint) { + t.Helper() + + if err := ep.Listen(1); err != nil { + t.Fatalf("ep.Listen(1): %s", err) + } + }, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + var addr tcpip.FullAddress + for { + newEP, wq, err := ep.Accept(&addr) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Accept(_): %s", err) + } + if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( + "NIC", + )); diff != "" { + t.Errorf("accepted address mismatch (-want +got):\n%s", diff) + } + + we, newCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + return newEP, newCH + } + }, + needRemoteAddr: false, + }, + } + + setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, target stack.Target) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + filter := ipt.GetTable(stack.NATID, ipv6) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} + filter.Rules[ruleIdx].Target = target + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } + } + + natTypes := []struct { + name string + setupNAT func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber, tcpip.Address) + }{ + { + name: "SNAT", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) { + t.Helper() + + setupNAT(t, s, netProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: natToAddr}) + }, + }, + { + name: "Masquerade", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) { + t.Helper() + + setupNAT(t, s, netProto, &stack.MasqueradeTarget{NetworkProtocol: netProto}) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + for _, natType := range natTypes { + t.Run(natType.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) + + natType.setupNAT(t, routerStack, test.netProto, epsAndAddrs.nattedClientAddr) + + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } + } + nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr} + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + nattedClientAddr.Port = addr.Port + } + + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch + } + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + var r bytes.Reader + r.Reset(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(&r, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } + } + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + var buf bytes.Buffer + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } + + { + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + read(serverCH, serverEP, data, nattedClientAddr) + } + + { + data := []byte{5, 6, 7, 8, 9, 10, 11, 12} + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) + } + }) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 27caa0c28..95ddd8ec3 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err) + if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err) + if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err) + if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err) + if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ @@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) { Protocol: test.networkProtocolNumber, AddressWithPrefix: test.incomingAddr, } - if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err) } // Set up endpoint through which we will attempt to forward packets. @@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) { Protocol: test.networkProtocolNumber, AddressWithPrefix: test.outgoingAddr, } - if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index b2008f0b2..f33223e79 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) + if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ { @@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ { @@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) + if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ { @@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) { if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) + v4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: utils.Ipv4Addr, } - if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) + if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err) + } + v6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: utils.Ipv6Addr, + } + if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err) } if err := s.CreateNIC(nicID2, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ipv4Loopback, + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) + } + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: header.IPv6Loopback.WithPrefix(), } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) + if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) } if test.forwarding { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 2d0a6e6a7..7753e7d6e 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) } ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err) } // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote @@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } var wq waiter.Queue @@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) } // Set the route table so that UDP can find a NIC that is diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index ac3c703d4..422eb8408 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) { // request/reply packets. icmpDataOffset = 8 ) - ipv4Loopback := testutil.MustParse4("127.0.0.1") + ipv4Loopback := tcpip.AddressWithPrefix{ + Address: testutil.MustParse4("127.0.0.1"), + PrefixLen: 8, + } channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { @@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) { transProto tcpip.TransportProtocolNumber netProto tcpip.NetworkProtocolNumber linkEndpoint func() stack.LinkEndpoint - localAddr tcpip.Address + localAddr tcpip.AddressWithPrefix icmpBuf func(*testing.T) buffer.View expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) @@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, - localAddr: header.IPv6Loopback, + localAddr: header.IPv6Loopback.WithPrefix(), icmpBuf: ipv6ICMPBuf, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, @@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, - localAddr: utils.Ipv4Addr.Address, + localAddr: utils.Ipv4Addr, icmpBuf: ipv4ICMPBuf, checkLinkEndpoint: channelEPCheck, }, @@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, - localAddr: utils.Ipv6Addr.Address, + localAddr: utils.Ipv6Addr, icmpBuf: ipv6ICMPBuf, checkLinkEndpoint: channelEPCheck, }, @@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + if len(test.localAddr.Address) != 0 { + protocolAddr := tcpip.ProtocolAddress{ + Protocol: test.netProto, + AddressWithPrefix: test.localAddr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) { } defer ep.Close() - connAddr := tcpip.FullAddress{Addr: test.localAddr} + connAddr := tcpip.FullAddress{Addr: test.localAddr.Address} if err := ep.Connect(connAddr); err != test.expectedConnectErr { t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) } @@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) { if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { t.Errorf("received data mismatch (-want +got):\n%s", diff) } - if rr.RemoteAddr.Addr != test.localAddr { - t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) + if rr.RemoteAddr.Addr != test.localAddr.Address { + t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address) } test.checkLinkEndpoint(t, e) @@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) { } if subTest.addAddress { - if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err) } - if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 2e6ae55ea..c69410859 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -40,6 +40,14 @@ const ( Host2NICID = 4 ) +// Common NIC names used by tests. +const ( + Host1NICName = "host1NIC" + RouterNIC1Name = "routerNIC1" + RouterNIC2Name = "routerNIC2" + Host2NICName = "host2NIC" +) + // Common link addresses used by tests. const ( LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") @@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2) routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4) - if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err) + { + opts := stack.NICOptions{Name: Host1NICName} + if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil { + t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err) + { + opts := stack.NICOptions{Name: RouterNIC1Name} + if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err) + { + opts := stack.NICOptions{Name: RouterNIC2Name} + if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err) + } } - if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) + { + opts := stack.NICOptions{Name: Host2NICName} + if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil { + t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err) + } } if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { @@ -231,29 +251,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err) } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err) + if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err) } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err) } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err) + if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err) } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err) } host1Stack.SetRouteTable([]tcpip.Route{ diff --git a/pkg/syserror/BUILD b/pkg/tcpip/transport/BUILD index 76bee5a64..af332ed91 100644 --- a/pkg/syserror/BUILD +++ b/pkg/tcpip/transport/BUILD @@ -3,8 +3,11 @@ load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( - name = "syserror", - srcs = ["syserror.go"], + name = "transport", + srcs = [ + "datagram.go", + "transport.go", + ], visibility = ["//visibility:public"], - deps = ["@org_golang_x_sys//unix:go_default_library"], + deps = ["//pkg/tcpip"], ) diff --git a/pkg/tcpip/transport/datagram.go b/pkg/tcpip/transport/datagram.go new file mode 100644 index 000000000..dfce72c69 --- /dev/null +++ b/pkg/tcpip/transport/datagram.go @@ -0,0 +1,49 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// DatagramEndpointState is the state of a datagram-based endpoint. +type DatagramEndpointState tcpip.EndpointState + +// The states a datagram-based endpoint may be in. +const ( + _ DatagramEndpointState = iota + DatagramEndpointStateInitial + DatagramEndpointStateBound + DatagramEndpointStateConnected + DatagramEndpointStateClosed +) + +// String implements fmt.Stringer. +func (s DatagramEndpointState) String() string { + switch s { + case DatagramEndpointStateInitial: + return "INITIAL" + case DatagramEndpointStateBound: + return "BOUND" + case DatagramEndpointStateConnected: + return "CONNECTED" + case DatagramEndpointStateClosed: + return "CLOSED" + default: + panic(fmt.Sprintf("unhandled %[1]T variant = %[1]d", s)) + } +} diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index bbc0e3ecc..4718ec4ec 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/waiter", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index f9a15efb2..31579a896 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,7 @@ package icmp import ( + "fmt" "io" "time" @@ -24,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -35,15 +38,6 @@ type icmpPacket struct { receivedAt time.Time `state:".(int64)"` } -type endpointState int - -const ( - stateInitial endpointState = iota - stateBound - stateConnected - stateClosed -) - // endpoint represents an ICMP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -51,14 +45,17 @@ const ( // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -70,38 +67,23 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - state endpointState - route *stack.Route `state:"manual"` - ttl uint8 - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool + ident uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, - state: stateInitial, uniqueID: s.UniqueID(), } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) + ep.net.Init(s, netProto, transProto, &ep.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -128,35 +110,40 @@ func (e *endpoint) Abort() { // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { - e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { - case stateBound, stateConnected: - bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) - } - - // Close the receive list and drain it. - e.rcvMu.Lock() - e.rcvClosed = true - e.rcvBufSize = 0 - for !e.rcvList.Empty() { - p := e.rcvList.Front() - e.rcvList.Remove(p) - } - e.rcvMu.Unlock() + notify := func() bool { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + return false + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + info := e.net.Info() + info.ID.LocalPort = e.ident + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice())) + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } - if e.route != nil { - e.route.Release() - e.route = nil - } + e.net.Shutdown() + e.net.Close() - // Update the state. - e.state = stateClosed + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } - e.mu.Unlock() + return true + }() - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + if notify { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + } } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. @@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -193,7 +180,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: p.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.receivedAt.UnixNano(), + Timestamp: p.receivedAt, }, } if opts.NeedRemoteAddr { @@ -214,13 +201,12 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // // Returns true for retry if preparation should be retried. // +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.state { - case stateInitial: - case stateConnected: +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - - case stateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { - return 0, err + return network.WriteContext{}, 0, err } if !retry { @@ -298,53 +272,35 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } } - route := e.route - if to != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := to.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - - dst, netProto, err := e.checkV4MappedLocked(*to) - if err != nil { - return 0, err - } - - // Find the endpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) - if err != nil { - return 0, err - } - defer r.Release() + ctx, err := e.net.AcquireContextForWrite(opts) + return ctx, e.ident, err +} - route = r +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + ctx, ident, err := e.prepareForWrite(opts) + if err != nil { + return 0, err } + defer ctx.Release() + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { return 0, &tcpip.ErrBadBuffer{} } - var err tcpip.Error - switch e.NetProto { + switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) + if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } case header.IPv6ProtocolNumber: - err = send6(route, e.ID.LocalPort, v, e.ttl) - } - - if err != nil { - return 0, err + if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } + default: + panic(fmt.Sprintf("unhandled network protocol = %d", netProto)) } return int64(len(v)), nil @@ -357,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool { return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt sets a socket option. -func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return nil +// SetSockOpt implements tcpip.Endpoint. +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + return e.net.SetSockOpt(opt) } -// SetSockOptInt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - } - return nil + return e.net.SetSockOptInt(opt, v) } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -387,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.rcvMu.Lock() - v := int(e.ttl) - e.rcvMu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +// GetSockOpt implements tcpip.Endpoint. +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { +func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength), }) - pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber @@ -426,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest - icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - pkt.Data().AppendView(data) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V4.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { +func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength), }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) @@ -468,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest pkt.Data().AppendView(data) dataRange := pkt.Data().AsRange() icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpv6, - Src: r.LocalAddress(), - Dst: r.RemoteAddress(), + Src: src, + Dst: dst, PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V6.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() + return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() tcpip.Error { return &tcpip.ErrNotSupported{} @@ -515,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - localPort := uint16(0) - switch e.state { - case stateInitial: - case stateBound, stateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.ident - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + nextID, err := e.registerWithStack(netProto, nextID) + if err != nil { + return err } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } - - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress(), - LocalPort: localPort, - RemoteAddress: r.RemoteAddress(), - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - id, err = e.registerWithStack(nicID, netProtos, id) + e.ident = nextID.LocalPort + return nil + }) if err != nil { - r.Release() return err } - e.ID = id - e.route = r - e.RegisterNICID = nicID - - e.state = stateConnected - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -585,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - e.shutdownFlags |= flags - if e.state != stateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } + + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } } if flags&tcpip.ShutdownRead != 0 { @@ -615,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) - return id, err + return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) + err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -644,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err := e.registerWithStack(boundNetProto, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, err = e.registerWithStack(addr.NIC, netProtos, id) + e.ident = id.LocalPort + return nil + }) if err != nil { return err } - e.ID = id - e.RegisterNICID = addr.NIC - - // Mark endpoint as bound. - e.state = stateBound - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -687,21 +567,24 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { return nil } +func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || + header.IsV4MulticastAddress(addr) || + header.IsV6MulticastAddress(addr) || + e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr) +} + // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - err := e.bindLocked(addr) - if err != nil { - return err + if len(addr.Addr) != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) { + return &tcpip.ErrBadLocalAddress{} } - e.BindNICID = addr.NIC - e.BindAddr = addr.Addr + e.mu.Lock() + defer e.mu.Unlock() - return nil + return e.bindLocked(addr) } // GetLocalAddress returns the address to which the endpoint is bound. @@ -709,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.ident + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -721,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { - return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} + if addr, connected := e.net.GetRemoteAddress(); connected { + return addr, nil } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness returns the current readiness of the endpoint. For example, if @@ -754,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. - switch e.NetProto { + switch e.net.NetProto() { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { @@ -828,9 +705,9 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + defer e.mu.RUnlock() + ret := e.net.Info() + ret.ID.LocalPort = e.ident return &ret } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index b8b839e4a..dfe453ff9 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,13 @@ package icmp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.thaw() + + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - if e.state != stateBound && e.state != stateConnected { - return - } - - var err tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + var err tcpip.Error + info := e.net.Info() + info.ID.LocalPort = e.ident + info.ID, err = e.registerWithStack(info.NetProto, info.ID) if err != nil { - panic(err) + panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err)) } - - e.ID.LocalAddress = e.route.LocalAddress() - } else if len(e.ID.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) - if err != nil { - panic(err) + e.ident = info.ID.LocalPort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go index cc950cbde..729f50e9a 100644 --- a/pkg/tcpip/transport/icmp/icmp_test.go +++ b/pkg/tcpip/transport/icmp/icmp_test.go @@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s t.Fatalf("s.CreateNIC(%d, _) = %s", id, err) } - if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addrV4.WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) } s.AddRoute(tcpip.Route{ diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD new file mode 100644 index 000000000..3818cb04e --- /dev/null +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -0,0 +1,46 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "network", + srcs = [ + "endpoint.go", + "endpoint_state.go", + ], + visibility = [ + "//pkg/tcpip/transport/icmp:__pkg__", + "//pkg/tcpip/transport/raw:__pkg__", + "//pkg/tcpip/transport/udp:__pkg__", + ], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + ], +) + +go_test( + name = "network_test", + size = "small", + srcs = ["endpoint_test.go"], + deps = [ + ":network", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/udp", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go new file mode 100644 index 000000000..e3094f59f --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -0,0 +1,811 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package network provides facilities to support tcpip.Endpoints that operate +// at the network layer or above. +package network + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" +) + +// Endpoint is a datagram-based endpoint. It only supports sending datagrams to +// a peer. +// +// +stateify savable +type Endpoint struct { + // The following fields must only be set once then never changed. + stack *stack.Stack `state:"manual"` + ops *tcpip.SocketOptions + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + wasBound bool + // owner is the owner of transmitted packets. + // + // +checklocks:mu + owner tcpip.PacketOwner + // +checklocks:mu + writeShutdown bool + // +checklocks:mu + effectiveNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu + connectedRoute *stack.Route `state:"manual"` + // +checklocks:mu + multicastMemberships map[multicastMembership]struct{} + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu + ttl uint8 + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu + multicastTTL uint8 + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu + multicastAddr tcpip.Address + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu + multicastNICID tcpip.NICID + // +checklocks:mu + ipv4TOS uint8 + // +checklocks:mu + ipv6TClass uint8 + + // Lock ordering: mu > infoMu. + infoMu sync.RWMutex `state:"nosave"` + // info has a dedicated mutex so that we can avoid lock ordering violations + // when reading the endpoint's info. If we used mu, we need to guarantee + // that any lock taken while mu is held is not held when calling Info() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setInfo. + // + // +checklocks:infoMu + info stack.TransportEndpointInfo + + // state holds a transport.DatagramBasedEndpointState. + // + // state must be accessed with atomics so that we can avoid lock ordering + // violations when reading the state. If we used mu, we need to guarantee + // that any lock taken while mu is held is not held when calling State() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setEndpointState. + // + // +checkatomics + state uint32 +} + +// +stateify savable +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + +// Init initializes the endpoint. +func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) { + e.mu.Lock() + memberships := e.multicastMemberships + e.mu.Unlock() + if memberships != nil { + panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships)) + } + + switch netProto { + case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber: + default: + panic(fmt.Sprintf("invalid protocol number = %d", netProto)) + } + + *e = Endpoint{ + stack: s, + ops: ops, + netProto: netProto, + transProto: transProto, + + info: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, + effectiveNetProto: netProto, + // Linux defaults to TTL=1. + multicastTTL: 1, + multicastMemberships: make(map[multicastMembership]struct{}), + } + + e.mu.Lock() + defer e.mu.Unlock() + e.setEndpointState(transport.DatagramEndpointStateInitial) +} + +// NetProto returns the network protocol the endpoint was initialized with. +func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { + return e.netProto +} + +// setEndpointState sets the state of the endpoint. +// +// e.mu must be held to synchronize changes to state with the rest of the +// endpoint. +// +// +checklocks:e.mu +func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { + atomic.StoreUint32(&e.state, uint32(state)) +} + +// State returns the state of the endpoint. +func (e *Endpoint) State() transport.DatagramEndpointState { + return transport.DatagramEndpointState(atomic.LoadUint32(&e.state)) +} + +// Close cleans the endpoint's resources and leaves the endpoint in a closed +// state. +func (e *Endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.State() == transport.DatagramEndpointStateClosed { + return + } + + for mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + + if e.connectedRoute != nil { + e.connectedRoute.Release() + e.connectedRoute = nil + } + + e.setEndpointState(transport.DatagramEndpointStateClosed) +} + +// SetOwner sets the owner of transmitted packets. +func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) { + e.mu.Lock() + defer e.mu.Unlock() + e.owner = owner +} + +func calculateTTL(route *stack.Route, ttl uint8, multicastTTL uint8) uint8 { + if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { + return multicastTTL + } + + if ttl == 0 { + return route.DefaultTTL() + } + + return ttl +} + +// WriteContext holds the context for a write. +type WriteContext struct { + transProto tcpip.TransportProtocolNumber + route *stack.Route + ttl uint8 + tos uint8 + owner tcpip.PacketOwner +} + +// Release releases held resources. +func (c *WriteContext) Release() { + c.route.Release() + *c = WriteContext{} +} + +// WritePacketInfo is the properties of a packet that may be written. +type WritePacketInfo struct { + NetProto tcpip.NetworkProtocolNumber + LocalAddress, RemoteAddress tcpip.Address + MaxHeaderLength uint16 + RequiresTXTransportChecksum bool +} + +// PacketInfo returns the properties of a packet that will be written. +func (c *WriteContext) PacketInfo() WritePacketInfo { + return WritePacketInfo{ + NetProto: c.route.NetProto(), + LocalAddress: c.route.LocalAddress(), + RemoteAddress: c.route.RemoteAddress(), + MaxHeaderLength: c.route.MaxHeaderLength(), + RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(), + } +} + +// WritePacket attempts to write the packet. +func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { + pkt.Owner = c.owner + + if headerIncluded { + return c.route.WriteHeaderIncludedPacket(pkt) + } + + return c.route.WritePacket(stack.NetworkHeaderParams{ + Protocol: c.transProto, + TTL: c.ttl, + TOS: c.tos, + }, pkt) +} + +// AcquireContextForWrite acquires a WriteContext. +func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. + if opts.More { + return WriteContext{}, &tcpip.ErrInvalidOptionValue{} + } + + if e.State() == transport.DatagramEndpointStateClosed { + return WriteContext{}, &tcpip.ErrInvalidEndpointState{} + } + + if e.writeShutdown { + return WriteContext{}, &tcpip.ErrClosedForSend{} + } + + route := e.connectedRoute + if opts.To == nil { + // If the user doesn't specify a destination, they should have + // connected to another address. + if e.State() != transport.DatagramEndpointStateConnected { + return WriteContext{}, &tcpip.ErrDestinationRequired{} + } + + route.Acquire() + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicID := opts.To.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } + info := e.Info() + if info.BindNICID != 0 { + if nicID != 0 && nicID != info.BindNICID { + return WriteContext{}, &tcpip.ErrNoRoute{} + } + + nicID = info.BindNICID + } + if nicID == 0 { + nicID = info.RegisterNICID + } + + dst, netProto, err := e.checkV4Mapped(*opts.To) + if err != nil { + return WriteContext{}, err + } + + route, _, err = e.connectRouteRLocked(nicID, dst, netProto) + if err != nil { + return WriteContext{}, err + } + } + + if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { + route.Release() + return WriteContext{}, &tcpip.ErrBroadcastDisabled{} + } + + var tos uint8 + switch netProto := route.NetProto(); netProto { + case header.IPv4ProtocolNumber: + tos = e.ipv4TOS + case header.IPv6ProtocolNumber: + tos = e.ipv6TClass + default: + panic(fmt.Sprintf("invalid protocol number = %d", netProto)) + } + + return WriteContext{ + transProto: e.transProto, + route: route, + ttl: calculateTTL(route, e.ttl, e.multicastTTL), + tos: tos, + owner: e.owner, + }, nil +} + +// Disconnect disconnects the endpoint from its peer. +func (e *Endpoint) Disconnect() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.State() != transport.DatagramEndpointStateConnected { + return + } + + info := e.Info() + // Exclude ephemerally bound endpoints. + if e.wasBound { + info.ID = stack.TransportEndpointID{ + LocalAddress: info.BindAddr, + } + e.setEndpointState(transport.DatagramEndpointStateBound) + } else { + info.ID = stack.TransportEndpointID{} + e.setEndpointState(transport.DatagramEndpointStateInitial) + } + e.setInfo(info) + + e.connectedRoute.Release() + e.connectedRoute = nil +} + +// connectRouteRLocked establishes a route to the specified interface or the +// configured multicast interface if no interface is specified and the +// specified address is a multicast address. +// +// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. +// +checklocks:e.mu +func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { + localAddr := e.Info().ID.LocalAddress + if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" + } + + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { + if nicID == 0 { + nicID = e.multicastNICID + } + if localAddr == "" && nicID == 0 { + localAddr = e.multicastAddr + } + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) + if err != nil { + return nil, 0, err + } + return r, nicID, nil +} + +// Connect connects the endpoint to the address. +func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error { + return nil + }) +} + +// ConnectAndThen connects the endpoint to the address and then calls the +// provided function. +// +// If the function returns an error, the endpoint's state does not change. The +// function will be called with the network protocol used to connect to the peer +// and the source and destination addresses that will be used to send traffic to +// the peer. +func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error { + addr.Port = 0 + + e.mu.Lock() + defer e.mu.Unlock() + + info := e.Info() + nicID := addr.NIC + switch e.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + if info.BindNICID == 0 { + break + } + + if nicID != 0 && nicID != info.BindNICID { + return &tcpip.ErrInvalidEndpointState{} + } + + nicID = info.BindNICID + default: + return &tcpip.ErrInvalidEndpointState{} + } + + addr, netProto, err := e.checkV4Mapped(addr) + if err != nil { + return err + } + + r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto) + if err != nil { + return err + } + + id := stack.TransportEndpointID{ + LocalAddress: info.ID.LocalAddress, + RemoteAddress: r.RemoteAddress(), + } + if e.State() == transport.DatagramEndpointStateInitial { + id.LocalAddress = r.LocalAddress() + } + + if err := f(r.NetProto(), info.ID, id); err != nil { + return err + } + + if e.connectedRoute != nil { + // If the endpoint was previously connected then release any previous route. + e.connectedRoute.Release() + } + e.connectedRoute = r + info.ID = id + info.RegisterNICID = nicID + e.setInfo(info) + e.effectiveNetProto = netProto + e.setEndpointState(transport.DatagramEndpointStateConnected) + return nil +} + +// Shutdown shutsdown the endpoint. +func (e *Endpoint) Shutdown() tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + e.writeShutdown = true + return nil + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } +} + +// checkV4MappedRLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { + info := e.Info() + unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) + if err != nil { + return tcpip.FullAddress{}, 0, err + } + return unwrapped, netProto, nil +} + +func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) +} + +// Bind binds the endpoint to the address. +func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { + return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error { + return nil + }) +} + +// BindAndThen binds the endpoint to the address and then calls the provided +// function. +// +// If the function returns an error, the endpoint's state does not change. The +// function will be called with the bound network protocol and address. +func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error { + addr.Port = 0 + + e.mu.Lock() + defer e.mu.Unlock() + + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.State() != transport.DatagramEndpointStateInitial { + return &tcpip.ErrInvalidEndpointState{} + } + + addr, netProto, err := e.checkV4Mapped(addr) + if err != nil { + return err + } + + nicID := addr.NIC + if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { + nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr) + if nicID == 0 { + return &tcpip.ErrBadLocalAddress{} + } + } + + if err := f(netProto, addr.Addr); err != nil { + return err + } + + e.wasBound = true + + info := e.Info() + info.ID = stack.TransportEndpointID{ + LocalAddress: addr.Addr, + } + info.BindNICID = addr.NIC + info.RegisterNICID = nicID + info.BindAddr = addr.Addr + e.setInfo(info) + e.effectiveNetProto = netProto + e.setEndpointState(transport.DatagramEndpointStateBound) + return nil +} + +// WasBound returns true iff the endpoint was ever bound. +func (e *Endpoint) WasBound() bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.wasBound +} + +// GetLocalAddress returns the address that the endpoint is bound to. +func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { + e.mu.RLock() + defer e.mu.RUnlock() + + info := e.Info() + addr := info.BindAddr + if e.State() == transport.DatagramEndpointStateConnected { + addr = e.connectedRoute.LocalAddress() + } + + return tcpip.FullAddress{ + NIC: info.RegisterNICID, + Addr: addr, + } +} + +// GetRemoteAddress returns the address that the endpoint is connected to. +func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.State() != transport.DatagramEndpointStateConnected { + return tcpip.FullAddress{}, false + } + + return tcpip.FullAddress{ + Addr: e.connectedRoute.RemoteAddress(), + NIC: e.Info().RegisterNICID, + }, true +} + +// SetSockOptInt sets the socket option. +func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { + switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return &tcpip.ErrNotSupported{} + } + + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() + + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + case tcpip.IPv4TOSOption: + e.mu.Lock() + e.ipv4TOS = uint8(v) + e.mu.Unlock() + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.ipv6TClass = uint8(v) + e.mu.Unlock() + } + + return nil +} + +// GetSockOptInt returns the socket option. +func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { + switch opt { + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + + case tcpip.MulticastTTLOption: + e.mu.Lock() + v := int(e.multicastTTL) + e.mu.Unlock() + return v, nil + + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.mu.RLock() + v := int(e.ipv4TOS) + e.mu.RUnlock() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.ipv6TClass) + e.mu.RUnlock() + return v, nil + + default: + return -1, &tcpip.ErrUnknownProtocolOption{} + } +} + +// SetSockOpt sets the socket option. +func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + switch v := opt.(type) { + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + defer e.mu.Unlock() + + fa := tcpip.FullAddress{Addr: v.InterfaceAddr} + fa, netProto, err := e.checkV4Mapped(fa) + if err != nil { + return err + } + nic := v.NIC + addr := fa.Addr + + if nic == 0 && addr == "" { + e.multicastAddr = "" + e.multicastNICID = 0 + break + } + + if nic != 0 { + if !e.stack.CheckNIC(nic) { + return &tcpip.ErrBadLocalAddress{} + } + } else { + nic = e.stack.CheckLocalAddress(0, netProto, addr) + if nic == 0 { + return &tcpip.ErrBadLocalAddress{} + } + } + + if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic { + return &tcpip.ErrInvalidEndpointState{} + } + + e.multicastNICID = nic + e.multicastAddr = addr + + case *tcpip.AddMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return &tcpip.ErrInvalidOptionValue{} + } + + nicID := v.NIC + + if v.InterfaceAddr.Unspecified() { + if nicID == 0 { + if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return &tcpip.ErrUnknownDevice{} + } + + memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.multicastMemberships[memToInsert]; ok { + return &tcpip.ErrPortInUse{} + } + + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships[memToInsert] = struct{}{} + + case *tcpip.RemoveMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return &tcpip.ErrInvalidOptionValue{} + } + + nicID := v.NIC + if v.InterfaceAddr.Unspecified() { + if nicID == 0 { + if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return &tcpip.ErrUnknownDevice{} + } + + memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.multicastMemberships[memToRemove]; !ok { + return &tcpip.ErrBadLocalAddress{} + } + + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + delete(e.multicastMemberships, memToRemove) + + case *tcpip.SocketDetachFilterOption: + return nil + } + return nil +} + +// GetSockOpt returns the socket option. +func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + switch o := opt.(type) { + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + *o = tcpip.MulticastInterfaceOption{ + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, + } + e.mu.Unlock() + + default: + return &tcpip.ErrUnknownProtocolOption{} + } + return nil +} + +// Info returns a copy of the endpoint info. +func (e *Endpoint) Info() stack.TransportEndpointInfo { + e.infoMu.RLock() + defer e.infoMu.RUnlock() + return e.info +} + +// setInfo sets the endpoint's info. +// +// e.mu must be held to synchronize changes to info with the rest of the +// endpoint. +// +// +checklocks:e.mu +func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) { + e.infoMu.Lock() + defer e.infoMu.Unlock() + e.info = info +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go new file mode 100644 index 000000000..68bd1fbf6 --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -0,0 +1,58 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" +) + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *Endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + + e.stack = s + + for m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(fmt.Sprintf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err)) + } + } + + info := e.Info() + + switch state := e.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound: + if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) { + if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 { + panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress)) + } + } + case transport.DatagramEndpointStateConnected: + var err tcpip.Error + multicastLoop := e.ops.GetMulticastLoop() + e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) + if err != nil { + panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) + } + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go new file mode 100644 index 000000000..f263a9ea2 --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint_test.go @@ -0,0 +1,318 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network_test + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +var ( + ipv4NICAddr = testutil.MustParse4("1.2.3.4") + ipv6NICAddr = testutil.MustParse6("a::1") + ipv4RemoteAddr = testutil.MustParse4("6.7.8.9") + ipv6RemoteAddr = testutil.MustParse6("b::1") +) + +func TestEndpointStateTransitions(t *testing.T) { + const nicID = 1 + + data := buffer.View([]byte{1, 2, 4, 5}) + v4Checker := func(t *testing.T, b buffer.View) { + checker.IPv4(t, b, + checker.SrcAddr(ipv4NICAddr), + checker.DstAddr(ipv4RemoteAddr), + checker.IPPayload(data), + ) + } + + v6Checker := func(t *testing.T, b buffer.View) { + checker.IPv6(t, b, + checker.SrcAddr(ipv6NICAddr), + checker.DstAddr(ipv6RemoteAddr), + checker.IPPayload(data), + ) + } + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + expectedMaxHeaderLength uint16 + expectedNetProto tcpip.NetworkProtocolNumber + expectedLocalAddr tcpip.Address + bindAddr tcpip.Address + expectedBoundAddr tcpip.Address + remoteAddr tcpip.Address + expectedRemoteAddr tcpip.Address + checker func(*testing.T, buffer.View) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + expectedMaxHeaderLength: header.IPv4MaximumHeaderSize, + expectedNetProto: ipv4.ProtocolNumber, + expectedLocalAddr: ipv4NICAddr, + bindAddr: header.IPv4AllSystems, + expectedBoundAddr: header.IPv4AllSystems, + remoteAddr: ipv4RemoteAddr, + expectedRemoteAddr: ipv4RemoteAddr, + checker: v4Checker, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + expectedMaxHeaderLength: header.IPv6FixedHeaderSize, + expectedNetProto: ipv6.ProtocolNumber, + expectedLocalAddr: ipv6NICAddr, + bindAddr: header.IPv6AllNodesMulticastAddress, + expectedBoundAddr: header.IPv6AllNodesMulticastAddress, + remoteAddr: ipv6RemoteAddr, + expectedRemoteAddr: ipv6RemoteAddr, + checker: v6Checker, + }, + { + name: "IPv4-mapped-IPv6", + netProto: ipv6.ProtocolNumber, + expectedMaxHeaderLength: header.IPv4MaximumHeaderSize, + expectedNetProto: ipv4.ProtocolNumber, + expectedLocalAddr: ipv4NICAddr, + bindAddr: testutil.MustParse6("::ffff:e000:0001"), + expectedBoundAddr: header.IPv4AllSystems, + remoteAddr: testutil.MustParse6("::ffff:0607:0809"), + expectedRemoteAddr: ipv4RemoteAddr, + checker: v4Checker, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + ipv4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4NICAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}: %s", nicID, ipv4ProtocolAddr, err) + } + ipv6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: ipv6NICAddr.WithPrefix(), + } + + if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + {Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID}, + {Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID}, + }) + + var ops tcpip.SocketOptions + var ep network.Endpoint + ep.Init(s, test.netProto, udp.ProtocolNumber, &ops) + defer ep.Close() + if state := ep.State(); state != transport.DatagramEndpointStateInitial { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial) + } + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + if state := ep.State(); state != transport.DatagramEndpointStateBound { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound) + } + if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedBoundAddr}); diff != "" { + t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff) + } + if addr, connected := ep.GetRemoteAddress(); connected { + t.Errorf("got ep.GetRemoteAddress() = (true, %#v), want = (false, _)", addr) + } + + connectAddr := tcpip.FullAddress{Addr: test.remoteAddr} + if err := ep.Connect(connectAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", connectAddr, err) + } + if state := ep.State(); state != transport.DatagramEndpointStateConnected { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected) + } + if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedLocalAddr}); diff != "" { + t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff) + } + if addr, connected := ep.GetRemoteAddress(); !connected { + t.Errorf("got ep.GetRemoteAddress() = (false, _), want = (true, %#v)", connectAddr) + } else if diff := cmp.Diff(addr, tcpip.FullAddress{Addr: test.expectedRemoteAddr}); diff != "" { + t.Errorf("remote address mismatch (-want +got):\n%s", diff) + } + + ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("ep.AcquireContexForWrite({}): %s", err) + } + defer ctx.Release() + info := ctx.PacketInfo() + if diff := cmp.Diff(network.WritePacketInfo{ + NetProto: test.expectedNetProto, + LocalAddress: test.expectedLocalAddr, + RemoteAddress: test.expectedRemoteAddr, + MaxHeaderLength: test.expectedMaxHeaderLength, + RequiresTXTransportChecksum: true, + }, info); diff != "" { + t.Errorf("write packet info mismatch (-want +got):\n%s", diff) + } + if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(info.MaxHeaderLength), + Data: data.ToVectorisedView(), + }), false /* headerIncluded */); err != nil { + t.Fatalf("ctx.WritePacket(_, false): %s", err) + } + if pkt, ok := e.Read(); !ok { + t.Fatalf("expected packet to be read from link endpoint") + } else { + test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader())) + } + + ep.Close() + if state := ep.State(); state != transport.DatagramEndpointStateClosed { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed) + } + }) + } +} + +func TestBindNICID(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + bindAddr tcpip.Address + unicast bool + }{ + { + name: "IPv4 multicast", + netProto: ipv4.ProtocolNumber, + bindAddr: header.IPv4AllSystems, + unicast: false, + }, + { + name: "IPv6 multicast", + netProto: ipv6.ProtocolNumber, + bindAddr: header.IPv6AllNodesMulticastAddress, + unicast: false, + }, + { + name: "IPv4 unicast", + netProto: ipv4.ProtocolNumber, + bindAddr: ipv4NICAddr, + unicast: true, + }, + { + name: "IPv6 unicast", + netProto: ipv6.ProtocolNumber, + bindAddr: ipv6NICAddr, + unicast: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, testBindNICID := range []tcpip.NICID{0, nicID} { + t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + ipv4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4NICAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err) + } + ipv6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: ipv6NICAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err) + } + + var ops tcpip.SocketOptions + var ep network.Endpoint + ep.Init(s, test.netProto, udp.ProtocolNumber, &ops) + defer ep.Close() + if ep.WasBound() { + t.Fatal("got ep.WasBound() = true, want = false") + } + wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber} + if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" { + t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff) + } + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + if !ep.WasBound() { + t.Error("got ep.WasBound() = false, want = true") + } + wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr} + wantInfo.BindAddr = bindAddr.Addr + wantInfo.BindNICID = bindAddr.NIC + if test.unicast { + wantInfo.RegisterNICID = nicID + } else { + wantInfo.RegisterNICID = bindAddr.NIC + } + if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" { + t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 8e7bb6c6e..80eef39e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,7 +25,6 @@ package packet import ( - "fmt" "io" "time" @@ -60,52 +59,47 @@ type packet struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue cooked bool - - // The following fields are used to manage the receive queue and are - // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList packetList + ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats + + // The following fields are used to manage the receive queue. + rcvMu sync.Mutex `state:"nosave"` + // +checklocks:rcvMu + rcvList packetList + // +checklocks:rcvMu rcvBufSize int - rcvClosed bool - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + // +checklocks:rcvMu + rcvClosed bool + // +checklocks:rcvMu + rcvDisabled bool + + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + closed bool + // +checklocks:mu + boundNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu boundNIC tcpip.NICID - // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` - lastError tcpip.Error - - // ops is used to get socket level options. - ops tcpip.SocketOptions - - // frozen indicates if the packets should be delivered to the endpoint - // during restore. - frozen bool + // +checklocks:lastErrorMu + lastError tcpip.Error } // NewEndpoint returns a new packet endpoint. func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - }, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, + stack: s, + cooked: cooked, + boundNetProto: netProto, + waiterQueue: waiterQueue, } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) @@ -141,7 +135,7 @@ func (ep *endpoint) Close() { return } - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) ep.rcvMu.Lock() defer ep.rcvMu.Unlock() @@ -154,7 +148,6 @@ func (ep *endpoint) Close() { } ep.closed = true - ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -189,7 +182,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul Total: packet.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: packet.receivedAt.UnixNano(), + Timestamp: packet.receivedAt, }, } if opts.NeedRemoteAddr { @@ -207,8 +200,52 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul return res, nil } -func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { - return 0, &tcpip.ErrInvalidOptionValue{} +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + if !ep.stack.PacketEndpointWriteSupported() { + return 0, &tcpip.ErrNotSupported{} + } + + ep.mu.Lock() + closed := ep.closed + nicID := ep.boundNIC + proto := ep.boundNetProto + ep.mu.Unlock() + if closed { + return 0, &tcpip.ErrClosedForSend{} + } + + var remote tcpip.LinkAddress + if to := opts.To; to != nil { + remote = tcpip.LinkAddress(to.Addr) + + if n := to.NIC; n != 0 { + nicID = n + } + + if p := to.Port; p != 0 { + proto = tcpip.NetworkProtocolNumber(p) + } + } + + if nicID == 0 { + return 0, &tcpip.ErrInvalidOptionValue{} + } + + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. + payloadBytes := make(buffer.View, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, &tcpip.ErrBadBuffer{} + } + + if err := func() tcpip.Error { + if ep.cooked { + return ep.stack.WritePacketToRemote(nicID, remote, proto, payloadBytes.ToVectorisedView()) + } + return ep.stack.WriteRawPacket(nicID, proto, payloadBytes.ToVectorisedView()) + }(); err != nil { + return 0, err + } + return int64(len(payloadBytes)), nil } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be @@ -253,29 +290,42 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound && ep.boundNIC == addr.NIC { - // If the NIC being bound is the same then just return success. + netProto := tcpip.NetworkProtocolNumber(addr.Port) + if netProto == 0 { + // Do not allow unbinding the network protocol. + netProto = ep.boundNetProto + } + + if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto { + // Already bound to the requested NIC and network protocol. return nil } - // Unregister endpoint with all the nics. - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) - ep.bound = false + // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new + // binding. + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) + ep.boundNIC = 0 + ep.boundNetProto = 0 // Bind endpoint to receive packets from specific interface. - if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } - ep.bound = true ep.boundNIC = addr.NIC - + ep.boundNetProto = netProto return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + ep.mu.RLock() + defer ep.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: ep.boundNIC, + Port: uint16(ep.boundNetProto), + }, nil } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. @@ -359,7 +409,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, _ tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -371,7 +421,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } rcvBufSize := ep.ops.GetReceiveBufferSize() - if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { + if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -380,76 +430,39 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, wasEmpty := ep.rcvBufSize == 0 - // Push new packet into receive list and increment the buffer size. - var packet packet + rcvdPkt := packet{ + packetInfo: tcpip.LinkPacketInfo{ + Protocol: netProto, + PktType: pkt.PktType, + }, + senderAddr: tcpip.FullAddress{ + NIC: nicID, + }, + receivedAt: ep.stack.Clock().Now(), + } + if !pkt.LinkHeader().View().IsEmpty() { - // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader().View()) - packet.senderAddr = tcpip.FullAddress{ - NIC: nicID, - Addr: tcpip.Address(hdr.SourceAddress()), - } - packet.packetInfo.Protocol = netProto - packet.packetInfo.PktType = pkt.PktType - } else { - // Guess the would-be ethernet header. - packet.senderAddr = tcpip.FullAddress{ - NIC: nicID, - Addr: tcpip.Address(localAddr), - } - packet.packetInfo.Protocol = netProto - packet.packetInfo.PktType = pkt.PktType + rcvdPkt.senderAddr.Addr = tcpip.Address(hdr.SourceAddress()) } if ep.cooked { - // Cooked packets can simply be queued. - switch pkt.PktType { - case tcpip.PacketHost: - packet.data = pkt.Data().ExtractVV() - case tcpip.PacketOutgoing: - // Strip Link Header. - var combinedVV buffer.VectorisedView - if v := pkt.NetworkHeader().View(); !v.IsEmpty() { - combinedVV.AppendView(v) - } - if v := pkt.TransportHeader().View(); !v.IsEmpty() { - combinedVV.AppendView(v) - } - combinedVV.Append(pkt.Data().ExtractVV()) - packet.data = combinedVV - default: - panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + // Cooked packet endpoints don't include the link-headers in received + // packets. + if v := pkt.NetworkHeader().View(); !v.IsEmpty() { + rcvdPkt.data.AppendView(v) } - } else { - // Raw packets need their ethernet headers prepended before - // queueing. - var linkHeader buffer.View - if pkt.PktType != tcpip.PacketOutgoing { - if pkt.LinkHeader().View().IsEmpty() { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, - } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - linkHeader = buffer.View(fakeHeader) - } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) - } - combinedVV := linkHeader.ToVectorisedView() - combinedVV.Append(pkt.Data().ExtractVV()) - packet.data = combinedVV - } else { - packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + if v := pkt.TransportHeader().View(); !v.IsEmpty() { + rcvdPkt.data.AppendView(v) } + rcvdPkt.data.Append(pkt.Data().ExtractVV()) + } else { + // Raw packet endpoints include link-headers in received packets. + rcvdPkt.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } - packet.receivedAt = ep.stack.Clock().Now() - ep.rcvList.PushBack(&packet) - ep.rcvBufSize += packet.data.Size() + ep.rcvList.PushBack(&rcvdPkt) + ep.rcvBufSize += rcvdPkt.data.Size() ep.rcvMu.Unlock() ep.stats.PacketsReceived.Increment() @@ -467,10 +480,8 @@ func (*endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() - // Make a copy of the endpoint info. - ret := ep.TransportEndpointInfo - ep.mu.RUnlock() - return &ret + defer ep.mu.RUnlock() + return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto} } // Stats returns a pointer to the endpoint stats. @@ -485,18 +496,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } - -// freeze prevents any more packets from being delivered to the endpoint. -func (ep *endpoint) freeze() { - ep.mu.Lock() - ep.frozen = true - ep.mu.Unlock() -} - -// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows -// new packets to be delivered again. -func (ep *endpoint) thaw() { - ep.mu.Lock() - ep.frozen = false - ep.mu.Unlock() -} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index e729921db..88cd80ad3 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -34,33 +35,34 @@ func (p *packet) loadReceivedAt(nsec int64) { // saveData saves packet.data field. func (p *packet) saveData() buffer.VectorisedView { - // We cannot save p.data directly as p.data.views may alias to p.views, - // which is not allowed by state framework (in-struct pointer). return p.data.Clone(nil) } // loadData loads packet.data field. func (p *packet) loadData(data buffer.VectorisedView) { - // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization - // here because data.views is not guaranteed to be loaded by now. Plus, - // data.views will be allocated anyway so there really is little point - // of utilizing p.views for data.views. p.data = data } // beforeSave is invoked by stateify. func (ep *endpoint) beforeSave() { - ep.freeze() + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + ep.rcvDisabled = true } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - ep.thaw() + ep.mu.Lock() + defer ep.mu.Unlock() + ep.stack = stack.StackFromEnv ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. - if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(err) + if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil { + panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err)) } + + ep.rcvMu.Lock() + ep.rcvDisabled = false + ep.rcvMu.Unlock() } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 2eab09088..b7e97e218 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/packet", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b3d8951ff..181b478d0 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,6 +26,7 @@ package raw import ( + "fmt" "io" "time" @@ -34,6 +35,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -57,15 +60,19 @@ type rawPacket struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue associated bool + net network.Endpoint + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions + // The following fields are used to manage the receive queue and are // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` @@ -74,20 +81,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - closed bool - connected bool - bound bool - // route is the route to a remote network endpoint. It is set via - // Connect(), and is valid only when conneted is true. - route *stack.Route `state:"manual"` - stats tcpip.TransportEndpointStats `state:"nosave"` - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - + mu sync.RWMutex `state:"nosave"` // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool @@ -99,16 +93,9 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { - if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { - return nil, &tcpip.ErrUnknownProtocol{} - } - e := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, associated: associated, } @@ -116,6 +103,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) e.ops.SetReceiveBufferSize(32*1024, false /* notify */) + e.net.Init(s, netProto, transProto, &e.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -137,7 +125,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { return nil, err } @@ -154,11 +142,17 @@ func (e *endpoint) Close() { e.mu.Lock() defer e.mu.Unlock() - if e.closed || !e.associated { + if e.net.State() == transport.DatagramEndpointStateClosed { return } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) + e.net.Close() + + if !e.associated { + return + } + + e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e) e.rcvMu.Lock() defer e.rcvMu.Unlock() @@ -170,15 +164,6 @@ func (e *endpoint) Close() { e.rcvList.Remove(e.rcvList.Front()) } - e.connected = false - - if e.route != nil { - e.route.Release() - e.route = nil - } - - e.closed = true - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -186,9 +171,7 @@ func (e *endpoint) Close() { func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.mu.Lock() - defer e.mu.Unlock() - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -219,7 +202,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: pkt.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: pkt.receivedAt.UnixNano(), + Timestamp: pkt.receivedAt, }, } if opts.NeedRemoteAddr { @@ -236,14 +219,15 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Write implements tcpip.Endpoint.Write. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + netProto := e.net.NetProto() // We can create, but not write to, unassociated IPv6 endpoints. - if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + if !e.associated && netProto == header.IPv6ProtocolNumber { return 0, &tcpip.ErrInvalidOptionValue{} } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. - if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { + if netProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { return 0, &tcpip.ErrInvalidOptionValue{} } } @@ -269,79 +253,25 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} + ctx, err := e.net.AcquireContextForWrite(opts) + if err != nil { + return 0, err } - payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - if e.closed { - return nil, nil, nil, &tcpip.ErrInvalidEndpointState{} - } - - payloadBytes := make([]byte, p.Len()) - if _, err := io.ReadFull(p, payloadBytes); err != nil { - return nil, nil, nil, &tcpip.ErrBadBuffer{} - } - - // Did the user caller provide a destination? If not, use the connected - // destination. - if opts.To == nil { - // If the user doesn't specify a destination, they should have - // connected to another address. - if !e.connected { - return nil, nil, nil, &tcpip.ErrDestinationRequired{} - } - e.route.Acquire() - - return payloadBytes, e.route, e.owner, nil - } - - // The caller provided a destination. Reject destination address if it - // goes through a different NIC than the endpoint was bound to. - nic := opts.To.NIC - if e.bound && nic != 0 && nic != e.BindNICID { - return nil, nil, nil, &tcpip.ErrNoRoute{} - } + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, &tcpip.ErrBadBuffer{} + } - // Find the route to the destination. If BindAddress is 0, - // FindRoute will choose an appropriate source address. - route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) - if err != nil { - return nil, nil, nil, err - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength), + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) - return payloadBytes, route, e.owner, nil - }() - if err != nil { + if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil { return 0, err } - defer route.Release() - - if e.ops.GetHeaderIncluded() { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.View(payloadBytes).ToVectorisedView(), - }) - if err := route.WriteHeaderIncludedPacket(pkt); err != nil { - return 0, err - } - } else { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(route.MaxHeaderLength()), - Data: buffer.View(payloadBytes).ToVectorisedView(), - }) - pkt.Owner = owner - if err := route.WritePacket(stack.NetworkHeaderParams{ - Protocol: e.TransProto, - TTL: route.DefaultTTL(), - TOS: stack.DefaultTOS, - }, pkt); err != nil { - return 0, err - } - } return int64(len(payloadBytes)), nil } @@ -353,66 +283,29 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + netProto := e.net.NetProto() + // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. - if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { + if netProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { return &tcpip.ErrAddressFamilyNotSupported{} } - e.mu.Lock() - defer e.mu.Unlock() - - if e.closed { - return &tcpip.ErrInvalidEndpointState{} - } - - nic := addr.NIC - if e.bound { - if e.BindNICID == 0 { - // If we're bound, but not to a specific NIC, the NIC - // in addr will be used. Nothing to do here. - } else if addr.NIC == 0 { - // If we're bound to a specific NIC, but addr doesn't - // specify a NIC, use the bound NIC. - nic = e.BindNICID - } else if addr.NIC != e.BindNICID { - // We're bound and addr specifies a NIC. They must be - // the same. - return &tcpip.ErrInvalidEndpointState{} - } - } - - // Find a route to the destination. - route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) - if err != nil { - return err - } - - if e.associated { - // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { - route.Release() - return err + return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error { + if e.associated { + // Re-register the endpoint with the appropriate NIC. + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { + return err + } + e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e) } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) - e.RegisterNICID = nic - } - - if e.route != nil { - // If the endpoint was previously connected then release any previous route. - e.route.Release() - } - e.route = route - e.connected = true - return nil + return nil + }) } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - if !e.connected { + if e.net.State() != transport.DatagramEndpointStateConnected { return &tcpip.ErrNotConnected{} } return nil @@ -430,46 +323,26 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi // Bind implements tcpip.Endpoint.Bind. func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - // If a local address was specified, verify that it's valid. - if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} - } + return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error { + if !e.associated { + return nil + } - if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { return err } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) - e.RegisterNICID = addr.NIC - e.BindNICID = addr.NIC - } - - e.BindAddr = addr.Addr - e.bound = true - - return nil + e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e) + return nil + }) } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - addr := e.BindAddr - if e.connected { - addr = e.route.LocalAddress() - } - - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: addr, - // Linux returns the protocol in the port field. - Port: uint16(e.TransProto), - }, nil + a := e.net.GetLocalAddress() + // Linux returns the protocol in the port field. + a.Port = uint16(e.transProto) + return a, nil } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. @@ -502,17 +375,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil default: - return &tcpip.ErrUnknownProtocolOption{} + return e.net.SetSockOpt(opt) } } -func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { + return e.net.SetSockOptInt(opt, v) } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -529,100 +402,108 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { return v, nil default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - e.mu.RLock() - e.rcvMu.Lock() + notifyReadableEvents := func() bool { + e.mu.RLock() + defer e.mu.RUnlock() + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return false + } - // Drop the packet if our buffer is currently full or if this is an unassociated - // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only - // See: https://man7.org/linux/man-pages/man7/raw.7.html - // - // An IPPROTO_RAW socket is send only. If you really want to receive - // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. - // Note that packet sockets don't reassemble IP fragments, unlike raw - // sockets. - if e.rcvClosed || !e.associated { - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stack.Stats().DroppedPackets.Increment() - e.stats.ReceiveErrors.ClosedReceiver.Increment() - return - } + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return false + } - rcvBufSize := e.ops.GetReceiveBufferSize() - if e.frozen || e.rcvBufSize >= int(rcvBufSize) { - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stack.Stats().DroppedPackets.Increment() - e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() - return - } + srcAddr := pkt.Network().SourceAddress() + info := e.net.Info() - remoteAddr := pkt.Network().SourceAddress() + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: + // If connected, only accept packets from the remote address we + // connected to. + if info.ID.RemoteAddress != srcAddr { + return false + } - if e.bound { - // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != pkt.NICID { - e.rcvMu.Unlock() - e.mu.RUnlock() - return - } - // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != remoteAddr { - e.rcvMu.Unlock() - e.mu.RUnlock() - return + // Connected sockets may also have been bound to a specific + // address/NIC. + fallthrough + case transport.DatagramEndpointStateBound: + // If bound to a NIC, only accept data for that NIC. + if info.BindNICID != 0 && info.BindNICID != pkt.NICID { + return false + } + + // If bound to an address, only accept data for that address. + if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() { + return false + } + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } - } - // If connected, only accept packets from the remote address we - // connected to. - if e.connected && e.route.RemoteAddress() != remoteAddr { - e.rcvMu.Unlock() - e.mu.RUnlock() - return - } + wasEmpty := e.rcvBufSize == 0 - wasEmpty := e.rcvBufSize == 0 + // Push new packet into receive list and increment the buffer size. + packet := &rawPacket{ + senderAddr: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: srcAddr, + }, + } - // Push new packet into receive list and increment the buffer size. - packet := &rawPacket{ - senderAddr: tcpip.FullAddress{ - NIC: pkt.NICID, - Addr: remoteAddr, - }, - } + // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. + // We copy headers' underlying bytes because pkt.*Header may point to + // the middle of a slice, and another struct may point to the "outer" + // slice. Save/restore doesn't support overlapping slices and will fail. + // + // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports + // overlapping slices. + var combinedVV buffer.VectorisedView + if info.NetProto == header.IPv4ProtocolNumber { + networkHeader, transportHeader := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(networkHeader)+len(transportHeader)) + headers = append(headers, networkHeader...) + headers = append(headers, transportHeader...) + combinedVV = headers.ToVectorisedView() + } else { + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() + } + combinedVV.Append(pkt.Data().ExtractVV()) + packet.data = combinedVV + packet.receivedAt = e.stack.Clock().Now() - // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. - // We copy headers' underlying bytes because pkt.*Header may point to - // the middle of a slice, and another struct may point to the "outer" - // slice. Save/restore doesn't support overlapping slices and will fail. - var combinedVV buffer.VectorisedView - if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - headers := make(buffer.View, 0, len(network)+len(transport)) - headers = append(headers, network...) - headers = append(headers, transport...) - combinedVV = headers.ToVectorisedView() - } else { - combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() - } - combinedVV.Append(pkt.Data().ExtractVV()) - packet.data = combinedVV - packet.receivedAt = e.stack.Clock().Now() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() + e.stats.PacketsReceived.Increment() - e.rcvList.PushBack(packet) - e.rcvBufSize += packet.data.Size() - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stats.PacketsReceived.Increment() - // Notify waiters that there's data to be read. - if wasEmpty { + // Notify waiters that there is data to be read now. + return wasEmpty + }() + + if notifyReadableEvents { e.waiterQueue.Notify(waiter.ReadableEvents) } } @@ -634,10 +515,7 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { - e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + ret := e.net.Info() return &ret } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 39669b445..e74713064 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,6 +15,7 @@ package raw import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -60,35 +61,16 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.net.Resume(s) + e.thaw() e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - // If the endpoint is connected, re-connect. - if e.connected { - var err tcpip.Error - // TODO(gvisor.dev/issue/4906): Properly restore the route with the right - // remote address. We used to pass e.remote.RemoteAddress which was - // effectively the empty address but since moving e.route to hold a pointer - // to a route instead of the route by value, we pass the empty address - // directly. Obviously this was always wrong since we should provide the - // remote address we were connected to, to properly restore the route. - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false) - if err != nil { - panic(err) - } - } - - // If the endpoint is bound, re-bind. - if e.bound { - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - if e.associated { - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { - panic(err) + netProto := e.net.NetProto() + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { + panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err)) } } } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 8436d2cf0..20958d882 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -68,6 +68,7 @@ go_library( "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", + "//pkg/tcpip/internal/tcp", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", @@ -79,9 +80,10 @@ go_library( go_test( name = "tcp_x_test", - size = "medium", + size = "large", srcs = [ "dual_stack_test.go", + "rcv_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", "tcp_rack_test.go", @@ -96,6 +98,7 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", @@ -112,16 +115,6 @@ go_test( ) go_test( - name = "rcv_test", - size = "small", - srcs = ["rcv_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( name = "tcp_test", size = "small", srcs = [ diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index aa413ad05..caf14b0dc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -15,12 +15,12 @@ package tcp import ( + "container/list" "crypto/sha1" "encoding/binary" "fmt" "hash" "io" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" @@ -72,7 +72,8 @@ func encodeMSS(mss uint16) uint32 { // and must not be accessed or have its methods called concurrently as they // may mutate the stored objects. type listenContext struct { - stack *stack.Stack + stack *stack.Stack + protocol *protocol // rcvWnd is the receive window that is sent by this listening context // in the initial SYN-ACK. @@ -99,18 +100,6 @@ type listenContext struct { // netProto indicates the network protocol(IPv4/v6) for the listening // endpoint. netProto tcpip.NetworkProtocolNumber - - // pendingMu protects pendingEndpoints. This should only be accessed - // by the listening endpoint's worker goroutine. - // - // Lock Ordering: listenEP.workerMu -> pendingMu - pendingMu sync.Mutex - // pending is used to wait for all pendingEndpoints to finish when - // a socket is closed. - pending sync.WaitGroup - // pendingEndpoints is a map of all endpoints for which a handshake is - // in progress. - pendingEndpoints map[stack.TransportEndpointID]*endpoint } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -119,15 +108,15 @@ func timeStamp(clock tcpip.Clock) uint32 { } // newListenContext creates a new listen context. -func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stk, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6Only: v6Only, - netProto: netProto, - listenEP: listenEP, - pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), + stack: stk, + protocol: protocol, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6Only: v6Only, + netProto: netProto, + listenEP: listenEP, } for i := range l.nonce { @@ -191,17 +180,9 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } -func (l *listenContext) useSynCookies() bool { - var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies - if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { - panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) - } - return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) -} - // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -213,7 +194,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header return nil, err } - n := newEndpoint(l.stack, netProto, queue) + n := newEndpoint(l.stack, l.protocol, netProto, queue) n.ops.SetV6Only(l.v6Only) n.TransportEndpointInfo.ID = s.id n.boundNICID = s.nicID @@ -244,10 +225,10 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header // On success, a handshake h is returned with h.ep.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { +func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed()) + isn := generateSecureISN(s.id, l.stack.Clock(), l.protocol.seqnumSecret) ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err @@ -271,18 +252,15 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q return nil, &tcpip.ErrConnectionAborted{} } - l.addPendingEndpoint(ep) // Propagate any inheritable options from the listening endpoint // to the newly created endpoint. - l.listenEP.propagateInheritableOptionsLocked(ep) + l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce if !ep.reserveTupleLocked() { ep.mu.Unlock() ep.Close() - l.removePendingEndpoint(ep) - return nil, &tcpip.ErrConnectionAborted{} } @@ -301,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - ep.drainClosingSegmentQueue() return nil, err @@ -323,7 +297,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // established endpoint is returned with e.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) { +func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) { h, err := l.startHandshake(s, opts, queue, owner) if err != nil { return nil, err @@ -342,39 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, return ep, nil } -func (l *listenContext) addPendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - l.pendingEndpoints[n.TransportEndpointInfo.ID] = n - l.pending.Add(1) - l.pendingMu.Unlock() -} - -func (l *listenContext) removePendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) - l.pending.Done() - l.pendingMu.Unlock() -} - -func (l *listenContext) closeAllPendingEndpoints() { - l.pendingMu.Lock() - for _, n := range l.pendingEndpoints { - n.notifyProtocolGoroutine(notifyClose) - } - l.pendingMu.Unlock() - l.pending.Wait() -} - -// Precondition: h.ep.mu must be held. // +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep e.mu.Unlock() e.Close() e.notifyAborted() - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.drainClosingSegmentQueue() e.h = nil } @@ -382,12 +329,9 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) { // cleanupCompletedHandshake transfers any state from the completed handshake to // the new endpoint. // -// Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e := h.ep - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.isConnectNotified = true // Update the receive window scaling. We can't do it before the @@ -399,47 +343,11 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e.h = nil } -// deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// listener has transitioned out of the listen state (accepted is the zero -// value), the new endpoint is reset instead. -func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { - e.mu.Lock() - e.pendingAccepted.Add(1) - e.mu.Unlock() - defer e.pendingAccepted.Done() - - // Drop the lock before notifying to avoid deadlock in user-specified - // callbacks. - delivered := func() bool { - e.acceptMu.Lock() - defer e.acceptMu.Unlock() - for { - if e.accepted == (accepted{}) { - return false - } - if e.accepted.endpoints.Len() == e.accepted.cap { - e.acceptCond.Wait() - continue - } - - e.accepted.endpoints.PushBack(n) - if !withSynCookie { - atomic.AddInt32(&e.synRcvdCount, -1) - } - return true - } - }() - if delivered { - e.waiterQueue.Notify(waiter.ReadableEvents) - } else { - n.notifyProtocolGoroutine(notifyReset) - } -} - // propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. // -// Precondition: e.mu and n.mu must be held. +// +checklocks:e.mu +// +checklocks:n.mu func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout n.portFlags = e.portFlags @@ -450,9 +358,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // reserveTupleLocked reserves an accepted endpoint's tuple. // -// Preconditions: -// * propagateInheritableOptionsLocked has been called. -// * e.mu is held. +// Precondition: e.propagateInheritableOptionsLocked has been called. +// +// +checklocks:e.mu func (e *endpoint) reserveTupleLocked() bool { dest := tcpip.FullAddress{ Addr: e.TransportEndpointInfo.ID.RemoteAddress, @@ -487,70 +395,36 @@ func (e *endpoint) notifyAborted() { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } -// handleSynSegment is called in its own goroutine once the listening endpoint -// receives a SYN segment. It is responsible for completing the handshake and -// queueing the new endpoint for acceptance. -// -// A limited number of these goroutines are allowed before TCP starts using SYN -// cookies to accept connections. -// -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error { - defer s.decRef() - - h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) - if err != nil { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - atomic.AddInt32(&e.synRcvdCount, -1) - return err - } +func (e *endpoint) acceptQueueIsFull() bool { + e.acceptMu.Lock() + full := e.acceptQueue.isFull() + e.acceptMu.Unlock() + return full +} - go func() { - // Note that startHandshake returns a locked endpoint. The - // force call here just makes it so. - if err := h.complete(); err != nil { // +checklocksforce - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - ctx.cleanupFailedHandshake(h) - atomic.AddInt32(&e.synRcvdCount, -1) - return - } - ctx.cleanupCompletedHandshake(h) - h.ep.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep, false /*withSynCookie*/) - }() +// +stateify savable +type acceptQueue struct { + // NB: this could be an endpointList, but ilist only permits endpoints to + // belong to one list at a time, and endpoints are already stored in the + // dispatcher's list. + endpoints list.List `state:".([]*endpoint)"` - return nil -} + // pendingEndpoints is a set of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[*endpoint]struct{} -func (e *endpoint) synRcvdBacklogFull() bool { - e.acceptMu.Lock() - acceptedCap := e.accepted.cap - e.acceptMu.Unlock() - // The capacity of the accepted queue would always be one greater than the - // listen backlog. But, the SYNRCVD connections count is always checked - // against the listen backlog value for Linux parity reason. - // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 - // - // We maintain an equality check here as the synRcvdCount is incremented - // and compared only from a single listener context and the capacity of - // the accepted queue can only increase by a new listen call. - return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 + // capacity is the maximum number of endpoints that can be in endpoints. + capacity int } -func (e *endpoint) acceptQueueIsFull() bool { - e.acceptMu.Lock() - full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap - e.acceptMu.Unlock() - return full +func (a *acceptQueue) isFull() bool { + return a.endpoints.Len() == a.capacity } // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed @@ -578,11 +452,95 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } opts := parseSynSegmentOptions(s) - if !ctx.useSynCookies() { - s.incRef() - atomic.AddInt32(&e.synRcvdCount, 1) - return e.handleSynSegment(ctx, s, &opts) + + useSynCookies, err := func() (bool, tcpip.Error) { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + if alwaysUseSynCookies { + return true, nil + } + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + + // The capacity of the accepted queue would always be one greater than the + // listen backlog. But, the SYNRCVD connections count is always checked + // against the listen backlog value for Linux parity reason. + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 + if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 { + return true, nil + } + + h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) + if err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return false, err + } + + e.acceptQueue.pendingEndpoints[h.ep] = struct{}{} + e.pendingAccepted.Add(1) + + go func() { + defer func() { + e.pendingAccepted.Done() + + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + delete(e.acceptQueue.pendingEndpoints, h.ep) + }() + + // Note that startHandshake returns a locked endpoint. The force call + // here just makes it so. + if err := h.complete(); err != nil { // +checklocksforce + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + ctx.cleanupFailedHandshake(h) + return + } + ctx.cleanupCompletedHandshake(h) + h.ep.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + + // Deliver the endpoint to the accept queue. + // + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + // The listener is transitioning out of the Listen state; bail. + if e.acceptQueue.capacity == 0 { + return false + } + if e.acceptQueue.isFull() { + e.acceptCond.Wait() + continue + } + + e.acceptQueue.endpoints.PushBack(h.ep) + return true + } + }() + + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + h.ep.notifyProtocolGoroutine(notifyReset) + } + }() + + return false, nil + }() + if err != nil { + return err } + if !useSynCookies { + return nil + } + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err @@ -600,10 +558,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())), TSEcr: opts.TSVal, MSS: calculateAdvertisedMSS(e.userMSS, route), } + if opts.TS { + offset := e.protocol.tsOffset(s.dstAddr, s.srcAddr) + now := e.stack.Clock().NowMonotonic() + synOpts.TSVal = offset.TSVal(now) + } cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) fields := tcpFields{ id: s.id, @@ -621,18 +583,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil case s.flags.Contains(header.TCPFlagAck): - if e.acceptQueueIsFull() { - // Silently drop the ack as the application can't accept - // the connection at this point. The ack will be - // retransmitted by the sender anyway and we can - // complete the connection at the time of retransmit if - // the backlog has space. - e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() - e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() - e.stack.Stats().DroppedPackets.Increment() - return nil - } - iss := s.ackNumber - 1 irs := s.sequenceNumber - 1 @@ -668,9 +618,27 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // ACK was received from the sender. return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } + + // Keep hold of acceptMu until the new endpoint is in the accept queue (or + // if there is an error), to guarantee that we will keep our spot in the + // queue even if another handshake from the syn queue completes. + e.acceptMu.Lock() + if e.acceptQueue.isFull() { + // Silently drop the ack as the application can't accept + // the connection at this point. The ack will be + // retransmitted by the sender anyway and we can + // complete the connection at the time of retransmit if + // the backlog has space. + e.acceptMu.Unlock() + e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() + e.stack.Stats().DroppedPackets.Increment() + return nil + } + e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. - rcvdSynOptions := &header.TCPSynOptions{ + rcvdSynOptions := header.TCPSynOptions{ MSS: mssTable[data], // Disable Window scaling as original SYN is // lost. @@ -689,6 +657,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { + e.acceptMu.Unlock() return err } @@ -700,6 +669,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err if !n.reserveTupleLocked() { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -717,6 +687,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.boundBindToDevice, ); err != nil { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -725,25 +696,22 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } n.isRegistered = true - - // clear the tsOffset for the newly created - // endpoint as the Timestamp was already - // randomly offset when the original SYN-ACK was - // sent above. - n.TSOffset = 0 + n.TSOffset = n.protocol.tsOffset(s.dstAddr, s.srcAddr) // Switch state to connected. n.isConnectNotified = true - n.transitionToStateEstablishedLocked(&handshake{ - ep: n, - iss: iss, - ackNum: irs + 1, - rcvWnd: seqnum.Size(n.initialReceiveWindow()), - sndWnd: s.window, - rcvWndScale: e.rcvWndScaleForHandshake(), - sndWndScale: rcvdSynOptions.WS, - mss: rcvdSynOptions.MSS, - }) + h := handshake{ + ep: n, + iss: iss, + ackNum: irs + 1, + rcvWnd: seqnum.Size(n.initialReceiveWindow()), + sndWnd: s.window, + rcvWndScale: e.rcvWndScaleForHandshake(), + sndWndScale: rcvdSynOptions.WS, + mss: rcvdSynOptions.MSS, + sampleRTTWithTSOnly: true, + } + h.transitionToStateEstablishedLocked(s) // Requeue the segment if the ACK completing the handshake has more info // to be procesed by the newly established endpoint. @@ -752,20 +720,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.newSegmentWaker.Assert() } - // Do the delivery in a separate goroutine so - // that we don't block the listen loop in case - // the application is slow to accept or stops - // accepting. - // - // NOTE: This won't result in an unbounded - // number of goroutines as we do check before - // entering here that there was at least some - // space available in the backlog. - // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n, true /*withSynCookie*/) + + // Deliver the endpoint to the accept queue. + e.acceptQueue.endpoints.PushBack(n) + e.acceptMu.Unlock() + + e.waiterQueue.Notify(waiter.ReadableEvents) return nil default: @@ -779,17 +742,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() v6Only := e.ops.GetV6Only() - ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) + ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto) defer func() { - // Mark endpoint as closed. This will prevent goroutines running - // handleSynSegment() from attempting to queue new connections - // to the endpoint. e.setEndpointState(StateClose) - // Close any endpoints in SYN-RCVD state. - ctx.closeAllPendingEndpoints() - // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 93ed161f9..80cd07218 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -30,6 +30,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// InitialRTO is the initial retransmission timeout. +// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142 +const InitialRTO = time.Second + // maxSegmentsPerWake is the maximum number of segments to process in the main // protocol goroutine per wake-up. Yielding [after this number of segments are // processed] allows other events to be processed as well (e.g., timeouts, @@ -105,6 +109,11 @@ type handshake struct { // sendSYNOpts is the cached values for the SYN options to be sent. sendSYNOpts header.TCPSynOptions + + // sampleRTTWithTSOnly is true when the segment was retransmitted or we can't + // tell; then RTT can only be sampled when the incoming segment has timestamp + // options enabled. + sampleRTTWithTSOnly bool } func (e *endpoint) newHandshake() *handshake { @@ -117,10 +126,12 @@ func (e *endpoint) newHandshake() *handshake { h.resetState() // Store reference to handshake state in endpoint. e.h = h + // By the time handshake is created, e.ID is already initialized. + e.TSOffset = e.protocol.tsOffset(e.ID.LocalAddress, e.ID.RemoteAddress) return h } -func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake { +func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) *handshake { h := e.newHandshake() h.resetToSynRcvd(isn, irs, opts, deferAccept) return h @@ -150,20 +161,23 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.protocol.seqnumSecret) } // generateSecureISN generates a secure Initial Sequence number based on the // recommendation here https://tools.ietf.org/html/rfc6528#page-3. func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value { isnHasher := jenkins.Sum32(seed) - isnHasher.Write([]byte(id.LocalAddress)) - isnHasher.Write([]byte(id.RemoteAddress)) + // Per hash.Hash.Writer: + // + // It never returns an error. + _, _ = isnHasher.Write([]byte(id.LocalAddress)) + _, _ = isnHasher.Write([]byte(id.RemoteAddress)) portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, id.LocalPort) - isnHasher.Write(portBuf) + _, _ = isnHasher.Write(portBuf) binary.LittleEndian.PutUint16(portBuf, id.RemotePort) - isnHasher.Write(portBuf) + _, _ = isnHasher.Write(portBuf) // The time period here is 64ns. This is similar to what linux uses // generate a sequence number that overlaps less than one // time per MSL (2 minutes). @@ -190,7 +204,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 { // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) { +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) { h.active = false h.state = handshakeSynRcvd h.flags = header.TCPFlagSyn | header.TCPFlagAck @@ -251,10 +265,10 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { rcvSynOpts := parseSynSegmentOptions(s) // Remember if the Timestamp option was negotiated. - h.ep.maybeEnableTimestamp(&rcvSynOpts) + h.ep.maybeEnableTimestamp(rcvSynOpts) // Remember if the SACKPermitted option was negotiated. - h.ep.maybeEnableSACKPermitted(&rcvSynOpts) + h.ep.maybeEnableSACKPermitted(rcvSynOpts) // Remember the sequence we'll ack from now on. h.ackNum = s.sequenceNumber + 1 @@ -266,8 +280,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // and the handshake is completed. if s.flags.Contains(header.TCPFlagAck) { h.state = handshakeCompleted - - h.ep.transitionToStateEstablishedLocked(h) + h.transitionToStateEstablishedLocked(s) h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil @@ -283,7 +296,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, - TSVal: h.ep.timestamp(), + TSVal: h.ep.tsValNow(), TSEcr: h.ep.recentTimestamp(), // We only send SACKPermitted if the other side indicated it @@ -353,7 +366,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: h.ep.SendTSOk, - TSVal: h.ep.timestamp(), + TSVal: h.ep.tsValNow(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.SACKPermitted, MSS: h.ep.amss, @@ -402,9 +415,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { if h.ep.SendTSOk && s.parsedOptions.TS { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } + h.state = handshakeCompleted - h.ep.transitionToStateEstablishedLocked(h) + h.transitionToStateEstablishedLocked(s) // Requeue the segment if the ACK completing the handshake has more info // to be procesed by the newly established endpoint. @@ -480,7 +494,7 @@ func (h *handshake) start() { synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, - TSVal: h.ep.timestamp(), + TSVal: h.ep.tsValNow(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, @@ -522,7 +536,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -557,6 +571,10 @@ func (h *handshake) complete() tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, h.sendSYNOpts) + // If we have ever retransmitted the SYN-ACK or + // SYN segment, we should only measure RTT if + // TS option is present. + h.sampleRTTWithTSOnly = true } case wakerForNotification: @@ -564,6 +582,9 @@ func (h *handshake) complete() tcpip.Error { if (n¬ifyClose)|(n¬ifyAbort) != 0 { return &tcpip.ErrAborted{} } + if n¬ifyShutdown != 0 { + return &tcpip.ErrConnectionReset{} + } if n¬ifyDrain != 0 { for !h.ep.segmentQueue.empty() { s := h.ep.segmentQueue.dequeue() @@ -600,6 +621,40 @@ func (h *handshake) complete() tcpip.Error { return nil } +// transitionToStateEstablisedLocked transitions the endpoint of the handshake +// to an established state given the last segment received from peer. It also +// initializes sender/receiver. +func (h *handshake) transitionToStateEstablishedLocked(s *segment) { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + h.ep.snd = newSender(h.ep, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + + now := h.ep.stack.Clock().NowMonotonic() + + var rtt time.Duration + if h.ep.SendTSOk && s.parsedOptions.TSEcr != 0 { + rtt = h.ep.elapsed(now, s.parsedOptions.TSEcr) + } + if !h.sampleRTTWithTSOnly && rtt == 0 { + rtt = now.Sub(h.startTime) + } + + if rtt > 0 { + h.ep.snd.updateRTO(rtt) + } + + h.ep.rcvQueueInfo.rcvQueueMu.Lock() + h.ep.rcv = newReceiver(h.ep, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + h.ep.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd) + h.ep.rcvQueueInfo.rcvQueueMu.Unlock() + + h.ep.setEndpointState(StateEstablished) +} + type backoffTimer struct { timeout time.Duration maxTimeout time.Duration @@ -873,7 +928,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) + offset += header.EncodeTSOption(e.tsValNow(), e.recentTimestamp(), options[offset:]) } if e.SACKPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) @@ -965,26 +1020,6 @@ func (e *endpoint) completeWorkerLocked() { } } -// transitionToStateEstablisedLocked transitions a given endpoint -// to an established state using the handshake parameters provided. -// It also initializes sender/receiver. -func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - e.rcvQueueInfo.rcvQueueMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) - // Bootstrap the auto tuning algorithm. Starting at zero will - // result in a really large receive window after the first auto - // tuning adjustment. - e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd) - e.rcvQueueInfo.rcvQueueMu.Unlock() - - e.setEndpointState(StateEstablished) -} - // transitionToStateCloseLocked ensures that the endpoint is // cleaned up from the transport demuxer, "before" moving to // StateClose. This will ensure that no packet will be @@ -1286,7 +1321,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoopDone is called at the end of protocolMainLoop. // +checklocksrelease:e.mu -func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) { +func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer) { if e.snd != nil { e.snd.resendTimer.cleanup() e.snd.probeTimer.cleanup() @@ -1314,7 +1349,7 @@ func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *slee // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) { var ( closeTimer tcpip.Timer closeWaker sleep.Waker @@ -1331,8 +1366,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.hardError = err e.workerCleanup = true - e.protocolMainLoopDone(closeTimer, &closeWaker) - return err + e.protocolMainLoopDone(closeTimer) + return } } @@ -1559,8 +1594,8 @@ loop: // just want to terminate the loop and cleanup the // endpoint. cleanupOnError(nil) - e.protocolMainLoopDone(closeTimer, &closeWaker) - return nil + e.protocolMainLoopDone(closeTimer) + return case StateTimeWait: fallthrough case StateClose: @@ -1568,8 +1603,8 @@ loop: default: if err := funcs[v].f(); err != nil { cleanupOnError(err) - e.protocolMainLoopDone(closeTimer, &closeWaker) - return nil + e.protocolMainLoopDone(closeTimer) + return } } } @@ -1592,21 +1627,19 @@ loop: // Handle any StateError transition from StateTimeWait. if e.EndpointState() == StateError { cleanupOnError(nil) - e.protocolMainLoopDone(closeTimer, &closeWaker) - return nil + e.protocolMainLoopDone(closeTimer) + return } e.transitionToStateCloseLocked() - e.protocolMainLoopDone(closeTimer, &closeWaker) + e.protocolMainLoopDone(closeTimer) // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue if reuseTW != nil { reuseTW() } - - return nil } // handleTimeWaitSegments processes segments received during TIME_WAIT diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 044123185..6a798e980 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,12 +15,10 @@ package tcp import ( - "container/list" "encoding/binary" "fmt" "io" "math" - "math/rand" "runtime" "strings" "sync/atomic" @@ -188,6 +186,8 @@ const ( // say TIME_WAIT. notifyTickleWorker notifyError + // notifyShutdown means that a connecting socket was shutdown. + notifyShutdown ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -204,6 +204,8 @@ type SACKInfo struct { } // ReceiveErrors collect segment receive errors within transport layer. +// +// +stateify savable type ReceiveErrors struct { tcpip.ReceiveErrors @@ -233,6 +235,8 @@ type ReceiveErrors struct { } // SendErrors collect segment send errors within the transport layer. +// +// +stateify savable type SendErrors struct { tcpip.SendErrors @@ -256,6 +260,8 @@ type SendErrors struct { } // Stats holds statistics about the endpoint. +// +// +stateify savable type Stats struct { // SegmentsReceived is the number of TCP segments received that // the transport layer successfully parsed. @@ -310,15 +316,6 @@ type rcvQueueInfo struct { rcvQueue segmentList `state:"wait"` } -// +stateify savable -type accepted struct { - // NB: this could be an endpointList, but ilist only permits endpoints to - // belong to one list at a time, and endpoints are already stored in the - // dispatcher's list. - endpoints list.List `state:".([]*endpoint)"` - cap int -} - // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -334,7 +331,7 @@ type accepted struct { // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects accepted. +// e.acceptMu -> Protects e.acceptQueue. // e.rcvQueueMu -> Protects e.rcvQueue and associated fields. // e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. @@ -378,6 +375,7 @@ type endpoint struct { // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` + protocol *protocol `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` uniqueID uint64 @@ -497,10 +495,6 @@ type endpoint struct { // and dropped when it is. segmentQueue segmentQueue `state:"wait"` - // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state; this is only accessed atomically. - synRcvdCount int32 - // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. userMSS uint16 @@ -573,7 +567,8 @@ type endpoint struct { // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. - accepted accepted + // +checklocks:acceptMu + acceptQueue acceptQueue // The following are only used from the protocol goroutine, and // therefore don't need locks to protect them. @@ -606,8 +601,7 @@ type endpoint struct { gso stack.GSO - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats Stats `state:"nosave"` + stats Stats // tcpLingerTimeout is the maximum amount of a time a socket // a socket stays in TIME_WAIT state before being marked @@ -803,9 +797,10 @@ type keepalive struct { waker sleep.Waker `state:"nosave"` } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: s, + stack: s, + protocol: protocol, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: header.TCPProtocolNumber, @@ -874,7 +869,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.TSOffset = timeStampOffset(e.stack.Rand()) + e.acceptCond = sync.NewCond(&e.acceptMu) e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker) @@ -903,7 +898,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Check if there's anything in the accepted queue. if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() - if e.accepted.endpoints.Len() != 0 { + if e.acceptQueue.endpoints.Len() != 0 { result |= waiter.ReadableEvents } e.acceptMu.Unlock() @@ -1086,20 +1081,20 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - acceptedCopy := e.accepted - e.accepted = accepted{} - e.acceptMu.Unlock() - - if acceptedCopy == (accepted{}) { - return + // Close any endpoints in SYN-RCVD state. + for n := range e.acceptQueue.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) } - - e.acceptCond.Broadcast() - + e.acceptQueue.pendingEndpoints = nil // Reset all connections that are waiting to be accepted. - for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { + for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() { n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) } + e.acceptQueue.endpoints.Init() + e.acceptMu.Unlock() + + e.acceptCond.Broadcast() + // Wait for reset of all endpoints that are still waiting to be delivered to // the now closed accepted. e.pendingAccepted.Wait() @@ -1717,6 +1712,27 @@ func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) { return rcvBufSz } +// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize. +func (e *endpoint) OnSetSendBufferSize(sz int64) int64 { + atomic.StoreUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled, 1) + return sz +} + +// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters. +func (e *endpoint) WakeupWriters() { + e.LockUser() + defer e.UnlockUser() + + sendBufferSize := e.getSendBufferSize() + e.sndQueueInfo.sndQueueMu.Lock() + notify := (sendBufferSize - e.sndQueueInfo.SndBufUsed) >= e.sndQueueInfo.SndBufUsed>>1 + e.sndQueueInfo.sndQueueMu.Unlock() + + if notify { + e.waiterQueue.Notify(waiter.WritableEvents) + } +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 @@ -2038,7 +2054,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber) e.UnlockUser() if err != nil { return err @@ -2177,7 +2193,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) - h := jenkins.Sum32(e.stack.Seed()) + h := jenkins.Sum32(e.protocol.portOffsetSecret) for _, s := range [][]byte{ []byte(e.ID.LocalAddress), []byte(e.ID.RemoteAddress), @@ -2329,6 +2345,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.setEndpointState(StateEstablished) + // Set the new auto tuned send buffer size after entering + // established state. + e.ops.SetSendBufferSize(e.computeTCPSendBufferSize(), false /* notify */) } if run { @@ -2355,6 +2374,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() + + if e.EndpointState().connecting() { + // When calling shutdown(2) on a connecting socket, the endpoint must + // enter the error state. But this logic cannot belong to the shutdownLocked + // method because that method is called during a close(2) (and closing a + // connecting socket is not an error). + e.resetConnectionLocked(&tcpip.ErrConnectionReset{}) + e.notifyProtocolGoroutine(notifyShutdown) + e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) + return nil + } + return e.shutdownLocked(flags) } @@ -2455,21 +2486,22 @@ func (e *endpoint) listen(backlog int) tcpip.Error { if e.EndpointState() == StateListen && !e.closed { e.acceptMu.Lock() defer e.acceptMu.Unlock() - if e.accepted == (accepted{}) { - // listen is called after shutdown. - e.accepted.cap = backlog - e.shutdownFlags = 0 - e.rcvQueueInfo.rcvQueueMu.Lock() - e.rcvQueueInfo.RcvClosed = false - e.rcvQueueInfo.rcvQueueMu.Unlock() - } else { - // Adjust the size of the backlog iff we can fit - // existing pending connections into the new one. - if e.accepted.endpoints.Len() > backlog { - return &tcpip.ErrInvalidEndpointState{} - } - e.accepted.cap = backlog + + // Adjust the size of the backlog iff we can fit + // existing pending connections into the new one. + if e.acceptQueue.endpoints.Len() > backlog { + return &tcpip.ErrInvalidEndpointState{} } + e.acceptQueue.capacity = backlog + + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) + } + + e.shutdownFlags = 0 + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = false + e.rcvQueueInfo.rcvQueueMu.Unlock() // Notify any blocked goroutines that they can attempt to // deliver endpoints again. @@ -2505,8 +2537,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() - if e.accepted == (accepted{}) { - e.accepted.cap = backlog + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) + } + if e.acceptQueue.capacity == 0 { + e.acceptQueue.capacity = backlog } e.acceptMu.Unlock() @@ -2546,8 +2581,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. // Get the new accepted endpoint. var n *endpoint e.acceptMu.Lock() - if element := e.accepted.endpoints.Front(); element != nil { - n = e.accepted.endpoints.Remove(element).(*endpoint) + if element := e.acceptQueue.endpoints.Front(); element != nil { + n = e.acceptQueue.endpoints.Remove(element).(*endpoint) } e.acceptMu.Unlock() if n == nil { @@ -2763,13 +2798,20 @@ func (e *endpoint) updateSndBufferUsage(v int) { e.sndQueueInfo.sndQueueMu.Lock() notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1 e.sndQueueInfo.SndBufUsed -= v + + // Get the new send buffer size with auto tuning, but do not set it + // unless we decide to notify the writers. + newSndBufSz := e.computeTCPSendBufferSize() + // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1 + notify = notify && e.sndQueueInfo.SndBufUsed < int(newSndBufSz)>>1 e.sndQueueInfo.sndQueueMu.Unlock() if notify { + // Set the new send buffer size calculated from auto tuning. + e.ops.SetSendBufferSize(newSndBufSz, false /* notify */) e.waiterQueue.Notify(waiter.WritableEvents) } } @@ -2873,46 +2915,29 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, // maybeEnableTimestamp marks the timestamp option enabled for this endpoint if // the SYN options indicate that timestamp option was negotiated. It also // initializes the recentTS with the value provided in synOpts.TSval. -func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { +func (e *endpoint) maybeEnableTimestamp(synOpts header.TCPSynOptions) { if synOpts.TS { e.SendTSOk = true e.setRecentTimestamp(synOpts.TSVal) } } -// timestamp returns the timestamp value to be used in the TSVal field of the -// timestamp option for outgoing TCP segments for a given endpoint. -func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset) +func (e *endpoint) tsVal(now tcpip.MonotonicTime) uint32 { + return e.TSOffset.TSVal(now) } -// tcpTimeStamp returns a timestamp offset by the provided offset. This is -// not inlined above as it's used when SYN cookies are in use and endpoint -// is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 { - d := curTime.Sub(tcpip.MonotonicTime{}) - return uint32(d.Milliseconds()) + offset +func (e *endpoint) tsValNow() uint32 { + return e.tsVal(e.stack.Clock().NowMonotonic()) } -// timeStampOffset returns a randomized timestamp offset to be used when sending -// timestamp values in a timestamp option for a TCP segment. -func timeStampOffset(rng *rand.Rand) uint32 { - // Initialize a random tsOffset that will be added to the recentTS - // everytime the timestamp is sent when the Timestamp option is enabled. - // - // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on - // why this is required. - // - // NOTE: This is not completely to spec as normally this should be - // initialized in a manner analogous to how sequence numbers are - // randomized per connection basis. But for now this is sufficient. - return rng.Uint32() +func (e *endpoint) elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration { + return e.TSOffset.Elapsed(now, tsEcr) } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. -func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { +func (e *endpoint) maybeEnableSACKPermitted(synOpts header.TCPSynOptions) { var v tcpip.TCPSACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. @@ -2974,6 +2999,8 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { } s.Sender.RACKState = e.snd.rc.TCPRACKState + s.Sender.RetransmitTS = e.snd.retransmitTS + s.Sender.SpuriousRecovery = e.snd.spuriousRecovery return s } @@ -3091,3 +3118,36 @@ func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOpti Max: ss.Max, } } + +// computeTCPSendBufferSize implements auto tuning of send buffer size and +// returns the new send buffer size. +func (e *endpoint) computeTCPSendBufferSize() int64 { + curSndBufSz := int64(e.getSendBufferSize()) + + // Auto tuning is disabled when the user explicitly sets the send + // buffer size with SO_SNDBUF option. + if disabled := atomic.LoadUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled); disabled == 1 { + return curSndBufSz + } + + const packetOverheadFactor = 2 + curMSS := e.snd.MaxPayloadSize + numSeg := InitialCwnd + if numSeg < e.snd.SndCwnd { + numSeg = e.snd.SndCwnd + } + + // SndCwnd indicates the number of segments that can be sent. This means + // that the sender can send upto #SndCwnd segments and the send buffer + // size should be set to SndCwnd*MSS to accommodate sending of all the + // segments. + newSndBufSz := int64(numSeg * curMSS * packetOverheadFactor) + if newSndBufSz < curSndBufSz { + return curSndBufSz + } + if ss := GetTCPSendBufferLimits(e.stack); int64(ss.Max) < newSndBufSz { + newSndBufSz = int64(ss.Max) + } + + return newSndBufSz +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 952ccacdd..94072a115 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() { } // saveEndpoints is invoked by stateify. -func (a *accepted) saveEndpoints() []*endpoint { +func (a *acceptQueue) saveEndpoints() []*endpoint { acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { acceptedEndpoints[i] = e.Value.(*endpoint) @@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint { } // loadEndpoints is invoked by stateify. -func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { +func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) { for _, ep := range acceptedEndpoints { a.endpoints.PushBack(ep) } @@ -170,6 +170,7 @@ func (e *endpoint) Resume(s *stack.Stack) { snd.probeTimer.init(s.Clock(), &snd.probeWaker) } e.stack = s + e.protocol = protocolFromStack(s) e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() epState := EndpointState(e.origEndpointState) @@ -250,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() - backlog := e.accepted.cap + e.acceptMu.Lock() + backlog := e.acceptQueue.capacity + e.acceptMu.Unlock() if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2e709ed78..128ef09e3 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -54,7 +54,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), - listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), + listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), } } @@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } f := r.forwarder - ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{ + ep, err := f.listen.performHandshake(r.segment, header.TCPSynOptions{ MSS: r.synOptions.MSS, WS: r.synOptions.WS, TS: r.synOptions.TS, diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 18b834243..e4410ad93 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -23,8 +23,10 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" + "gvisor.dev/gvisor/pkg/tcpip/internal/tcp" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -49,10 +51,6 @@ const ( // MaxBufferSize is the largest size a receive/send buffer can grow to. MaxBufferSize = 4 << 20 // 4MB - // MaxUnprocessedSegments is the maximum number of unprocessed segments - // that can be queued for a given endpoint. - MaxUnprocessedSegments = 300 - // DefaultTCPLingerTimeout is the amount of time that sockets linger in // FIN_WAIT_2 state before being marked closed. DefaultTCPLingerTimeout = 60 * time.Second @@ -96,6 +94,11 @@ type protocol struct { maxRetries uint32 synRetries uint8 dispatcher dispatcher + + // The following secrets are initialized once and stay unchanged after. + seqnumSecret uint32 + portOffsetSecret uint32 + tsOffsetSecret uint32 } // Number returns the tcp protocol number. @@ -105,7 +108,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { // NewEndpoint creates a new tcp endpoint. func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return newEndpoint(p.stack, netProto, waiterQueue), nil + return newEndpoint(p.stack, p, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently @@ -156,6 +159,24 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, return stack.UnknownDestinationPacketHandled } +func (p *protocol) tsOffset(src, dst tcpip.Address) tcp.TSOffset { + // Initialize a random tsOffset that will be added to the recentTS + // everytime the timestamp is sent when the Timestamp option is enabled. + // + // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on + // why this is required. + // + // TODO(https://gvisor.dev/issues/6473): This is not really secure as + // it does not use the recommended algorithm linked above. + h := jenkins.Sum32(p.tsOffsetSecret) + // Per hash.Hash.Writer: + // + // It never returns an error. + _, _ = h.Write([]byte(src)) + _, _ = h.Write([]byte(dst)) + return tcp.NewTSOffset(h.Sum32()) +} + // replyWithReset replies to the given segment with a reset segment. // // If the passed TTL is 0, then the route's default TTL will be used. @@ -292,22 +313,26 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip case *tcpip.TCPMinRTOOption: p.mu.Lock() + defer p.mu.Unlock() if *v < 0 { p.minRTO = MinRTO + } else if minRTO := time.Duration(*v); minRTO <= p.maxRTO { + p.minRTO = minRTO } else { - p.minRTO = time.Duration(*v) + return &tcpip.ErrInvalidOptionValue{} } - p.mu.Unlock() return nil case *tcpip.TCPMaxRTOOption: p.mu.Lock() + defer p.mu.Unlock() if *v < 0 { p.maxRTO = MaxRTO + } else if maxRTO := time.Duration(*v); maxRTO >= p.minRTO { + p.maxRTO = maxRTO } else { - p.maxRTO = time.Duration(*v) + return &tcpip.ErrInvalidOptionValue{} } - p.mu.Unlock() return nil case *tcpip.TCPMaxRetriesOption: @@ -479,7 +504,15 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { maxRTO: MaxRTO, maxRetries: MaxRetries, recovery: tcpip.TCPRACKLossDetection, + seqnumSecret: s.Rand().Uint32(), + portOffsetSecret: s.Rand().Uint32(), + tsOffsetSecret: s.Rand().Uint32(), } p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p } + +// protocolFromStack retrieves the tcp.protocol instance from stack s. +func protocolFromStack(s *stack.Stack) *protocol { + return s.TransportProtocolInstance(ProtocolNumber).(*protocol) +} diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 0da4eafaa..3b055c294 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -80,7 +80,6 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime) - tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: @@ -92,7 +91,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // step 2 if seg.xmitCount > 1 { if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { - if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) { + if ackSeg.parsedOptions.TSEcr < rc.snd.ep.tsVal(seg.xmitTime) { return } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 9ce8fcae9..90e493978 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -477,7 +477,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. - if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 { + if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && (r.PendingBufUsed+int(segLen)) < int(rcvBufSize)>>2 { r.ep.rcvQueueInfo.rcvQueueMu.Lock() r.PendingBufUsed += s.segMemSize() r.ep.rcvQueueInfo.rcvQueueMu.Unlock() diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index 8a026ec46..e47a07030 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rcv_test +package tcp_test import ( "testing" diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 2e6ea06f5..2d5fdda19 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW DataSize: seg.data.Size(), SegMemSize: seg.segMemSize(), } - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("%s differs (-want +got):\n%s", name, diff) } } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 92a66f17e..4377f07a0 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -144,6 +144,15 @@ type sender struct { // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm. probeTimer timer `state:"nosave"` probeWaker sleep.Waker `state:"nosave"` + + // spuriousRecovery indicates whether the sender entered recovery + // spuriously as described in RFC3522 Section 3.2. + spuriousRecovery bool + + // retransmitTS is the timestamp at which the sender sends retransmitted + // segment after entering an RTO for the first time as described in + // RFC3522 Section 3.2. + retransmitTS uint32 } // rtt is a synchronization wrapper used to appease stateify. See the comment @@ -382,6 +391,9 @@ func (s *sender) updateRTO(rtt time.Duration) { if s.RTO < s.minRTO { s.RTO = s.minRTO } + if s.RTO > s.maxRTO { + s.RTO = s.maxRTO + } } // resendSegment resends the first unacknowledged segment. @@ -422,6 +434,13 @@ func (s *sender) retransmitTimerExpired() bool { return true } + // Initialize the variables used to detect spurious recovery after + // entering RTO. + // + // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1. + s.spuriousRecovery = false + s.retransmitTS = 0 + // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases // when writeList is empty. Remove this once we have a proper fix for this // issue. @@ -492,6 +511,10 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveRecovery() } + // Record retransmitTS if the sender is not in recovery as per: + // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + s.recordRetransmitTS() + s.state = tcpip.RTORecovery s.cc.HandleRTOExpired() @@ -955,6 +978,13 @@ func (s *sender) sendData() { } func (s *sender) enterRecovery() { + // Initialize the variables used to detect spurious recovery after + // entering recovery. + // + // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1. + s.spuriousRecovery = false + s.retransmitTS = 0 + s.FastRecovery.Active = true // Save state to reflect we're now in fast recovery. // @@ -969,6 +999,11 @@ func (s *sender) enterRecovery() { s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding s.FastRecovery.HighRxt = s.SndUna s.FastRecovery.RescueRxt = s.SndUna + + // Record retransmitTS if the sender is not in recovery as per: + // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + s.recordRetransmitTS() + if s.ep.SACKPermitted { s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() @@ -1144,13 +1179,15 @@ func (s *sender) isDupAck(seg *segment) bool { // Iterate the writeList and update RACK for each segment which is newly acked // either cumulatively or selectively. Loop through the segments which are // sacked, and update the RACK related variables and check for reordering. +// Returns true when the DSACK block has been detected in the received ACK. // // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. -func (s *sender) walkSACK(rcvdSeg *segment) { +func (s *sender) walkSACK(rcvdSeg *segment) bool { s.rc.setDSACKSeen(false) // Look for DSACK block. + hasDSACK := false idx := 0 n := len(rcvdSeg.parsedOptions.SACKBlocks) if checkDSACK(rcvdSeg) { @@ -1164,10 +1201,11 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.setDSACKSeen(true) idx = 1 n-- + hasDSACK = true } if n == 0 { - return + return hasDSACK } // Sort the SACK blocks. The first block is the most recent unacked @@ -1190,6 +1228,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { seg = seg.Next() } } + return hasDSACK } // checkDSACK checks if a DSACK is reported. @@ -1236,6 +1275,85 @@ func checkDSACK(rcvdSeg *segment) bool { return false } +func (s *sender) recordRetransmitTS() { + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 + // + // The Eifel detection algorithm is used, only upon initiation of loss + // recovery, i.e., when either the timeout-based retransmit or the fast + // retransmit is sent. The Eifel detection algorithm MUST NOT be + // reinitiated after loss recovery has already started. In particular, + // it must not be reinitiated upon subsequent timeouts for the same + // segment, and not upon retransmitting segments other than the oldest + // outstanding segment, e.g., during selective loss recovery. + if s.inRecovery() { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2 + // + // Set a "RetransmitTS" variable to the value of the Timestamp Value + // field of the Timestamps option included in the retransmit sent when + // loss recovery is initiated. A TCP sender must ensure that + // RetransmitTS does not get overwritten as loss recovery progresses, + // e.g., in case of a second timeout and subsequent second retransmit of + // the same octet. + s.retransmitTS = s.ep.tsValNow() +} + +func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) { + // Return if the sender has already detected spurious recovery. + if s.spuriousRecovery { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 4 + // + // If the value of the Timestamp Echo Reply field of the acceptable ACK's + // Timestamps option is smaller than the value of RetransmitTS, then + // proceed to next step, else return. + if tsEchoReply >= s.retransmitTS { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5 + // + // If the acceptable ACK carries a DSACK option [RFC2883], then return. + if hasDSACK { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5 + // + // If during the lifetime of the TCP connection the TCP sender has + // previously received an ACK with a DSACK option, or the acceptable ACK + // does not acknowledge all outstanding data, then proceed to next step, + // else return. + numDSACK := s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.Value() + if numDSACK == 0 && s.SndUna == s.SndNxt { + return + } + + // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 6 + // + // If the loss recovery has been initiated with a timeout-based + // retransmit, then set + // SpuriousRecovery <- SPUR_TO (equal 1), + // else set + // SpuriousRecovery <- dupacks+1 + // Set the spurious recovery variable to true as we do not differentiate + // between fast, SACK or RTO recovery. + s.spuriousRecovery = true + s.ep.stack.Stats().TCP.SpuriousRecovery.Increment() +} + +// Check if the sender is in RTORecovery, FastRecovery or SACKRecovery state. +func (s *sender) inRecovery() bool { + if s.state == tcpip.RTORecovery || s.state == tcpip.FastRecovery || s.state == tcpip.SACKRecovery { + return true + } + return false +} + // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { @@ -1251,6 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Insert SACKBlock information into our scoreboard. + hasDSACK := false if s.ep.SACKPermitted { for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds @@ -1285,7 +1404,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // RACK.fack, then the corresponding packet has been // reordered and RACK.reord is set to TRUE. if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { - s.walkSACK(rcvdSeg) + hasDSACK = s.walkSACK(rcvdSeg) } s.SetPipe() } @@ -1342,10 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // some new data, i.e., only if it advances the left edge of // the send window. if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { - // TSVal/Ecr values sent by Netstack are at a millisecond - // granularity. - elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond - s.updateRTO(elapsed) + s.updateRTO(s.ep.elapsed(s.ep.stack.Clock().NowMonotonic(), rcvdSeg.parsedOptions.TSEcr)) } if s.shouldSchedulePTO() { @@ -1415,12 +1531,14 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ackLeft -= datalen } - // Update the send buffer usage and notify potential waiters. - s.ep.updateSndBufferUsage(int(acked)) - // Clear SACK information for all acked data. s.ep.scoreboard.Delete(s.SndUna) + // Detect if the sender entered recovery spuriously. + if s.inRecovery() { + s.detectSpuriousRecovery(hasDSACK, rcvdSeg.parsedOptions.TSEcr) + } + // If we are not in fast recovery then update the congestion // window based on the number of acknowledged packets. if !s.FastRecovery.Active { @@ -1437,6 +1555,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } + // Update the send buffer usage and notify potential waiters. + s.ep.updateSndBufferUsage(int(acked)) + // It is possible for s.outstanding to drop below zero if we get // a retransmit timeout, reset outstanding to zero but later // get an ack that cover previously sent data. diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 89e9fb886..0d36d0dd0 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -33,7 +33,6 @@ const ( tsOptionSize = 12 maxTCPOptionSize = 40 mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload - latency = 5 * time.Millisecond ) func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) { @@ -163,7 +162,10 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en if !enableRACK { setStackTCPRecovery(t, c, 0) } - createConnectedWithSACKAndTS(c) + // The delay should be below initial RTO (1s) otherwise retransimission + // will start. Choose a relatively large value so that estimated RTT + // keeps high even after a few rounds of undelayed RTT samples. + c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */) data := make([]byte, numPackets*maxPayload) for i := range data { @@ -181,9 +183,6 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en for i := 0; i < numPackets; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload - // This delay is added to increase RTT as low RTT can cause TLP - // before sending ACK. - time.Sleep(latency) } return data @@ -1060,16 +1059,17 @@ func TestRACKWithWindowFull(t *testing.T) { for i := 0; i < numPkts; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload - if i == 0 { - // Send ACK for the first packet to establish RTT. - c.SendAck(seq, maxPayload) - } } - // SACK for #10 packet. - start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) + // Expect retransmission of last packet due to TLP. + c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, tsOptionSize) + + // SACK for first and last packet. + start := c.IRS.Add(seqnum.Size(maxPayload)) end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}}) + dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) var info tcpip.TCPInfoOption if err := c.EP.GetSockOpt(&info); err != nil { diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 83e0653b9..896249d2d 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -23,6 +23,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -35,13 +36,13 @@ import ( // SACKPermitted option enabled if the stack in the context has the SACK support // enabled. func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) + return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) } // createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS // option enabled if the stack in the context has SACK and TS enabled. func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) + return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) } func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { @@ -108,7 +109,7 @@ func TestSackDisabledConnect(t *testing.T) { setStackSACKPermitted(t, c, sackEnabled) setStackTCPRecovery(t, c, 0) - rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) + rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) data := []byte{1, 2, 3} @@ -170,7 +171,7 @@ func TestSackPermittedAccept(t *testing.T) { setStackSACKPermitted(t, c, sackEnabled) setStackTCPRecovery(t, c, 0) - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) + rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) // Now verify no SACK blocks are // received when sack is disabled. data := []byte{1, 2, 3} @@ -244,7 +245,7 @@ func TestSackDisabledAccept(t *testing.T) { setStackSACKPermitted(t, c, sackEnabled) setStackTCPRecovery(t, c, 0) - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) + rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Now verify no SACK blocks are // received when sack is disabled. @@ -702,3 +703,257 @@ func TestRecoveryEntry(t *testing.T) { t.Error(err) } } + +func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery uint64) { + t.Helper() + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil + } + + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } +} + +func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b, data []byte) { + payloadLen := uint32(len(tcpHdr.Payload())) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead), + checker.TCPAckNum(context.TestInitialSequenceNumber+1), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + pdata := data[bytesRead : bytesRead+payloadLen] + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { + t.Fatalf("got data = %v, want = %v", p, pdata) + } +} + +func buildTSOptionFromHeader(tcpHdr header.TCP) []byte { + parsedOpts := tcpHdr.ParsedOptions() + tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) + return tsOpt[:] +} + +func TestDetectSpuriousRecoveryWithRTO(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { + if s.Sender.RetransmitTS == 0 { + t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + } + if !s.Sender.SpuriousRecovery { + t.Fatalf("Spurious recovery was not detected") + } + close(probeDone) + }) + + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := make([]byte, numPackets*maxPayload) + for i := range data { + data[i] = byte(i) + } + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + var options []byte + var bytesRead uint32 + for i := 0; i < numPackets; i++ { + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) + + // Get options only for the first packet. This will be sent with + // the ACK to indicate the acknowledgement is for the original + // packet. + if i == 0 && c.TimeStampEnabled { + options = buildTSOptionFromHeader(tcpHdr) + } + bytesRead += uint32(len(tcpHdr.Payload())) + } + + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Expect #5 segment with TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + // Expect #1 segment because of RTO. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + info := tcpip.TCPInfoOption{} + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + + if info.CcState != tcpip.RTORecovery { + t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery) + } + + // Acknowledge the data. + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), + RcvWnd: rcvWnd, + TCPOpts: options, + }) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone + + verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */) +} + +func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + numAck := 0 + probeDone := make(chan struct{}) + c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { + if numAck < 3 { + numAck++ + return + } + + if s.Sender.RetransmitTS == 0 { + t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + } + if !s.Sender.SpuriousRecovery { + t.Fatalf("Spurious recovery was not detected") + } + close(probeDone) + }) + + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := make([]byte, numPackets*maxPayload) + for i := range data { + data[i] = byte(i) + } + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + var options []byte + var bytesRead uint32 + for i := 0; i < numPackets; i++ { + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) + + // Get options only for the first packet. This will be sent with + // the ACK to indicate the acknowledgement is for the original + // packet. + if i == 0 && c.TimeStampEnabled { + options = buildTSOptionFromHeader(tcpHdr) + } + bytesRead += uint32(len(tcpHdr.Payload())) + } + + // Receive the retransmitted packet after TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Send ACK for #3 and #4 segments to avoid entering TLP. + start := c.IRS.Add(3*maxPayload + 1) + end := start.Add(2 * maxPayload) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + + c.SendAck(seq, 0 /* bytesReceived */) + c.SendAck(seq, 0 /* bytesReceived */) + + // Receive the retransmitted packet after three duplicate ACKs. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + info := tcpip.TCPInfoOption{} + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + + if info.CcState != tcpip.SACKRecovery { + t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery) + } + + // Acknowledge the data. + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), + RcvWnd: rcvWnd, + TCPOpts: options, + }) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone + + verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */) +} + +func TestNoSpuriousRecoveryWithDSACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + numPackets := 5 + data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) + + // Receive the retransmitted packet after TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + // Send ACK for #3 and #4 segments to avoid entering TLP. + start := c.IRS.Add(3*maxPayload + 1) + end := start.Add(2 * maxPayload) + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + + c.SendAck(seq, 0 /* bytesReceived */) + c.SendAck(seq, 0 /* bytesReceived */) + + // Receive the retransmitted packet after three duplicate ACKs. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + // Acknowledge the data with DSACK for #1 segment. + start = c.IRS.Add(maxPayload + 1) + end = start.Add(2 * maxPayload) + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}}) + + verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 031f01357..6f1ee3816 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -1381,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) { if err := s.CreateNIC(id, ep); err != nil { t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) } - if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { - t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{ {Destination: header.IPv4EmptySubnet, NIC: id}, @@ -1651,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) { } } +func TestShutdownConnectingSocket(t *testing.T) { + for _, test := range []struct { + name string + shutdownMode tcpip.ShutdownFlags + }{ + {"ShutdownRead", tcpip.ShutdownRead}, + {"ShutdownWrite", tcpip.ShutdownWrite}, + {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with + // the handshake process. + c.Create(-1) + + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + // Start connection attempt. + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) + } + + // Check the SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + if err := c.EP.Shutdown(test.shutdownMode); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + // The endpoint internal state is updated immediately. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + select { + case <-ch: + default: + t.Fatal("endpoint was not notified") + } + + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrConnectionReset{}) + + // If the endpoint is not properly shutdown, it'll re-attempt to connect + // by sending another ACK packet. + c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond)) + }) + } +} + func TestSynSent(t *testing.T) { for _, test := range []struct { name string @@ -1674,7 +1744,7 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } @@ -1990,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { ) // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the FIN but DON't ACK IT. checker.IPv4(t, c.GetPacket(), @@ -2006,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // Cause a RST to be generated by closing the read end now since we have // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the RST checker.IPv4(t, c.GetPacket(), @@ -2127,6 +2201,214 @@ func TestFullWindowReceive(t *testing.T) { ) } +func TestSmallReceiveBufferReadiness(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + }) + + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } + + const nicID = 1 + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err) + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x7f\x00\x00\x01"), + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err) + } + + { + subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") + if err != nil { + t.Fatalf("tcpip.NewSubnet failed: %s", err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: subnet, + NIC: nicID, + }, + }) + } + + listenerEntry, listenerCh := waiter.NewChannelEntry(nil) + var listenerWQ waiter.Queue + listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer listener.Close() + listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents) + defer listenerWQ.EventUnregister(&listenerEntry) + + if err := listener.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := listener.Listen(1); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + localAddress, err := listener.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress failed: %s", err) + } + + for i := 8; i > 0; i /= 2 { + size := int64(i << 10) + t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) { + var clientWQ waiter.Queue + client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer client.Close() + switch err := client.Connect(localAddress).(type) { + case nil: + t.Fatal("Connect returned nil error") + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect failed: %s", err) + } + + <-listenerCh + server, serverWQ, err := listener.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + defer server.Close() + + client.SocketOptions().SetReceiveBufferSize(size, true) + // Send buffer size doesn't seem to affect this test. + // server.SocketOptions().SetSendBufferSize(size, true) + + clientEntry, clientCh := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents) + defer clientWQ.EventUnregister(&clientEntry) + + serverEntry, serverCh := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverEntry, waiter.WritableEvents) + defer serverWQ.EventUnregister(&serverEntry) + + var total int64 + for { + var b [64 << 10]byte + var r bytes.Reader + r.Reset(b[:]) + switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + case nil: + t.Logf("wrote %d bytes", n) + total += n + continue + case *tcpip.ErrWouldBlock: + select { + case <-serverCh: + continue + case <-time.After(100 * time.Millisecond): + // Well and truly full. + t.Logf("send and receive queues are full") + } + default: + t.Fatalf("Write failed: %s", err) + } + break + } + t.Logf("wrote %d bytes in total", total) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(2) + go func() { + defer wg.Done() + + var b [64 << 10]byte + var r bytes.Reader + r.Reset(b[:]) + if err := func() error { + var total int64 + defer t.Logf("wrote %d bytes in total", total) + for r.Len() != 0 { + switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + case nil: + t.Logf("wrote %d bytes", n) + total += n + case *tcpip.ErrWouldBlock: + for { + t.Logf("waiting on server") + select { + case <-serverCh: + case <-time.After(time.Second): + if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 { + t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness) + } + continue + } + break + } + default: + return fmt.Errorf("server.Write failed: %s", err) + } + } + if err := server.Shutdown(tcpip.ShutdownWrite); err != nil { + return fmt.Errorf("server.Shutdown failed: %s", err) + } + t.Logf("server end shutdown done") + return nil + }(); err != nil { + t.Error(err) + } + }() + + go func() { + defer wg.Done() + + if err := func() error { + total := 0 + defer t.Logf("read %d bytes in total", total) + for { + switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { + case nil: + t.Logf("read %d bytes", res.Count) + total += res.Count + t.Logf("read total %d bytes till now", total) + case *tcpip.ErrClosedForReceive: + return nil + case *tcpip.ErrWouldBlock: + for { + t.Logf("waiting on client") + select { + case <-clientCh: + case <-time.After(time.Second): + if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 { + return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness) + } + continue + } + break + } + default: + return fmt.Errorf("client.Write failed: %s", err) + } + } + }(); err != nil { + t.Error(err) + } + }() + }) + } +} + // Test the stack receive window advertisement on receiving segments smaller than // segment overhead. It tests for the right edge of the window to not grow when // the endpoint is not being read from. @@ -2143,7 +2425,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) } - c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) + c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) // Bump up the receive buffer size such that, when the receive window grows, // the scaled window exceeds maxUint16. @@ -2535,7 +2817,7 @@ func TestScaledWindowAccept(t *testing.T) { // Do 3-way handshake. // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 - c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) + c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) @@ -3532,6 +3814,12 @@ func TestMaxRetransmitsTimeout(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } + // Wait for the connection to timeout after MaxRetries retransmits. + initRTO := time.Second + minRTOOpt := tcpip.TCPMinRTOOption(initRTO) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -3554,8 +3842,6 @@ func TestMaxRetransmitsTimeout(t *testing.T) { ), ) } - // Wait for the connection to timeout after MaxRetries retransmits. - initRTO := 1 * time.Second select { case <-notifyCh: case <-time.After((2 << numRetries) * initRTO): @@ -3590,9 +3876,13 @@ func TestMaxRTO(t *testing.T) { defer c.Cleanup() rto := 1 * time.Second - opt := tcpip.TCPMaxRTOOption(rto) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + minRTOOpt := tcpip.TCPMinRTOOption(rto / 2) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } + maxRTOOpt := tcpip.TCPMaxRTOOption(rto) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err) } c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3618,8 +3908,8 @@ func TestMaxRTO(t *testing.T) { checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) - if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { - t.Errorf("Retransmit interval not capped to MaxRTO.\n") + if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() { + t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed) } } } @@ -3670,6 +3960,10 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + minRTOOpt := tcpip.TCPMinRTOOption(time.Second) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) // Disabling PMTU discovery causes all packets sent from this socket to @@ -4736,13 +5030,17 @@ func makeStack() (*stack.Stack, tcpip.Error) { } for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address + number tcpip.NetworkProtocolNumber + addrWithPrefix tcpip.AddressWithPrefix }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, + {ipv4.ProtocolNumber, context.StackAddrWithPrefix}, + {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix}, } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ct.number, + AddressWithPrefix: ct.addrWithPrefix, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { return nil, err } } @@ -4946,7 +5244,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) } for i := start; i <= end; i++ { - if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { + if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { t.Fatalf("Bind(%d) failed: %s", i, err) } } @@ -6304,7 +6602,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) + c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) @@ -6385,7 +6683,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // maximum buffer size defined above. c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) + rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4}) // NOTE: The timestamp values in the sent packets are meaningless to the // peer so we just increment the timestamp value by 1 every batch as we @@ -6515,7 +6813,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // maximum buffer size used by stack. c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) + rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4}) tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) @@ -7430,6 +7728,11 @@ func TestTCPUserTimeout(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + initRTO := 1 * time.Second + minRTOOpt := tcpip.TCPMinRTOOption(initRTO) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -7440,7 +7743,6 @@ func TestTCPUserTimeout(t *testing.T) { // Ensure that on the next retransmit timer fire, the user timeout has // expired. - initRTO := 1 * time.Second userTimeout := initRTO / 2 v := tcpip.TCPUserTimeoutOption(userTimeout) if err := c.EP.SetSockOpt(&v); err != nil { @@ -7954,6 +8256,151 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } } +func TestHandshakeRTT(t *testing.T) { + type testCase struct { + connect bool + tsEnabled bool + useCookie bool + retrans bool + delay time.Duration + wantRTT time.Duration + } + var testCases []testCase + for _, connect := range []bool{false, true} { + for _, tsEnabled := range []bool{false, true} { + for _, useCookie := range []bool{false, true} { + for _, retrans := range []bool{false, true} { + if connect && useCookie { + continue + } + delay := 800 * time.Millisecond + if retrans { + delay = 1200 * time.Millisecond + } + wantRTT := delay + // If syncookie is enabled, sample RTT only when TS option is enabled. + if !retrans && useCookie && !tsEnabled { + wantRTT = 0 + } + // If retransmitted, sample RTT only when TS option is enabled. + if retrans && !tsEnabled { + wantRTT = 0 + } + testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT}) + } + } + } + } + for _, tt := range testCases { + tt := tt + t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t)", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) { + t.Parallel() + c := context.New(t, defaultMTU) + if tt.useCookie { + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + } + synOpts := header.TCPSynOptions{} + if tt.tsEnabled { + synOpts.TS = true + synOpts.TSVal = 42 + } + if tt.connect { + c.CreateConnectedWithOptions(synOpts, tt.delay) + } else { + synOpts.MSS = defaultIPv4MSS + synOpts.WS = -1 + c.AcceptWithOptions(-1, synOpts, tt.delay) + } + var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT { + t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT) + } + if info.RTTVar != 0 && tt.wantRTT == 0 { + t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar) + } + if info.RTTVar == 0 && tt.wantRTT != 0 { + t.Fatalf("got info.RTTVar=0, expect non zero") + } + }) + } +} + +func TestSetRTO(t *testing.T) { + c := context.New(t, defaultMTU) + minRTO, maxRTO := tcpRTOMinMax(t, c) + for _, tt := range []struct { + name string + RTO time.Duration + minRTO time.Duration + maxRTO time.Duration + err tcpip.Error + }{ + { + name: "invalid minRTO", + minRTO: maxRTO + time.Second, + err: &tcpip.ErrInvalidOptionValue{}, + }, + { + name: "invalid maxRTO", + maxRTO: minRTO - time.Millisecond, + err: &tcpip.ErrInvalidOptionValue{}, + }, + { + name: "valid minRTO", + minRTO: maxRTO - time.Second, + }, + { + name: "valid maxRTO", + maxRTO: minRTO + time.Millisecond, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + var opt tcpip.SettableTransportProtocolOption + if tt.minRTO > 0 { + min := tcpip.TCPMinRTOOption(tt.minRTO) + opt = &min + } + if tt.maxRTO > 0 { + max := tcpip.TCPMaxRTOOption(tt.maxRTO) + opt = &max + } + err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt) + if got, want := err, tt.err; got != want { + t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want) + } + if tt.err == nil { + minRTO, maxRTO := tcpRTOMinMax(t, c) + if tt.minRTO > 0 && tt.minRTO != minRTO { + t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO) + } + if tt.maxRTO > 0 && tt.maxRTO != maxRTO { + t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO) + } + } + }) + } +} + +func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) { + t.Helper() + var minOpt tcpip.TCPMinRTOOption + var maxOpt tcpip.TCPMaxRTOOption + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &minOpt); err != nil { + t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", minOpt, err) + } + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &maxOpt); err != nil { + t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", maxOpt, err) + } + return time.Duration(minOpt), time.Duration(maxOpt) +} + // generateRandomPayload generates a random byte slice of the specified length // causing a fatal test failure if it is unable to do so. func generateRandomPayload(t *testing.T, n int) []byte { @@ -7964,3 +8411,192 @@ func generateRandomPayload(t *testing.T, n int) []byte { } return buf } + +func TestSendBufferTuning(t *testing.T) { + const maxPayload = 536 + const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload + const packetOverheadFactor = 2 + + testCases := []struct { + name string + autoTuningDisabled bool + }{ + {"autoTuningDisabled", true}, + {"autoTuningEnabled", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + // Set the stack option for send buffer size. + const defaultSndBufSz = maxPayload * tcp.InitialCwnd + const maxSndBufSz = defaultSndBufSz * 10 + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz} + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + } + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + + oldSz := c.EP.SocketOptions().GetSendBufferSize() + if oldSz != defaultSndBufSz { + t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz) + } + + if tc.autoTuningDisabled { + c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */) + } + + data := make([]byte, maxPayload) + for i := range data { + data[i] = byte(i) + } + + w, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&w, waiter.WritableEvents) + defer c.WQ.EventUnregister(&w) + + bytesRead := 0 + for { + // Packets will be sent till the send buffer + // size is reached. + var r bytes.Reader + r.Reset(data[bytesRead : bytesRead+maxPayload]) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + break + } + + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0) + bytesRead += maxPayload + data = append(data, data...) + } + + // Send an ACK and wait for connection to become writable again. + c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) + select { + case <-ch: + if err := c.EP.LastError(); err != nil { + t.Fatalf("Write failed: %s", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } + + outSz := int64(defaultSndBufSz) + if !tc.autoTuningDisabled { + // Calculate the new auto tuned send buffer. + var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload + } + + if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz { + t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz) + } + }) + } +} + +func TestTimestampSynCookies(t *testing.T) { + clock := faketime.NewManualClock() + tsNow := func() uint32 { + return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds()) + } + // Advance the clock so that NowMonotonic is non-zero. + clock.Advance(time.Second) + c := context.NewWithOpts(t, context.Options{ + EnableV4: true, + EnableV6: true, + MTU: defaultMTU, + Clock: clock, + }) + defer c.Cleanup() + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() + + tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(42, 0, tcpOpts[2:]) + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + RcvWnd: seqnum.Size(512), + SeqNum: iss, + TCPOpts: tcpOpts[:], + }) + // Get the TSVal of SYN-ACK. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + initialTSVal := tcpHdr.ParsedOptions().TSVal + // derive the tsOffset. + tsOffset := initialTSVal - tsNow() + + header.EncodeTSOption(420, initialTSVal, tcpOpts[2:]) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + RcvWnd: seqnum.Size(512), + SeqNum: iss + 1, + AckNum: c.IRS + 1, + TCPOpts: tcpOpts[:], + }) + c.EP, _, err = ep.Accept(nil) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + defer wq.EventUnregister(&we) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } else if err != nil { + t.Fatalf("failed to accept: %s", err) + } + + // Advance the clock again so that we expect the next TSVal to change. + clock.Advance(time.Second) + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // The endpoint should have a correct TSOffset so that the received TSVal + // should match our expectation. + if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want { + t.Fatalf("got TSVal = %d, want %d", got, want) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 1deb1fe4d..65925daa5 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -32,7 +32,7 @@ import ( // createConnectedWithTimestampOption creates and connects c.ep with the // timestamp option enabled. func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) + return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1}) } // TestTimeStampEnabledConnect tests that netstack sends the timestamp option on @@ -131,7 +131,7 @@ func TestTimeStampDisabledConnect(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithOptions(header.TCPSynOptions{}) + c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) } func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { @@ -147,7 +147,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) tsVal := rand.Uint32() - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) + c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) // Now send some data and validate that timestamp is echoed correctly in the ACK. data := []byte{1, 2, 3} @@ -209,7 +209,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd } t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) + c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Now send some data with the accepted connection endpoint and validate // that no timestamp option is sent in the TCP segment. diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 96e4849d2..88bb99354 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -122,6 +122,9 @@ type Options struct { // MTU indicates the maximum transmission unit on the link layer. MTU uint32 + + // Clock that is used by Stack. + Clock tcpip.Clock } // Context provides an initialized Network stack and a link layer endpoint @@ -182,6 +185,7 @@ func NewWithOpts(t *testing.T, opts Options) *Context { stackOpts := stack.Options{ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: opts.Clock, } if opts.EnableV4 { stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) @@ -239,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context { Protocol: ipv4.ProtocolNumber, AddressWithPrefix: StackAddrWithPrefix, } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err) } routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, @@ -253,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context { Protocol: ipv6.ProtocolNumber, AddressWithPrefix: StackV6AddrWithPrefix, } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err) } routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, @@ -879,13 +883,21 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { ) } +// CreateConnectedWithOptionsNoDelay just calls CreateConnectedWithOptions +// without delay. +func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint { + return c.CreateConnectedWithOptions(wantOptions, 0 /* delay */) +} + // CreateConnectedWithOptions creates and connects c.ep with the specified TCP // options enabled and returns a RawEndpoint which represents the other end of -// the connection. +// the connection. It delays before a SYNACK is sent. This makes c.EP have a +// higher RTT estimate so that spurious TLPs aren't sent in tests, which helps +// reduce flakiness. // // It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK // does not carry an option that was not requested. -func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { +func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint { var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { @@ -911,18 +923,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // TS value. mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{ - MSS: mss, - TS: true, - WS: int(c.WindowScale), - SACKPermitted: c.SACKEnabled(), - }), - ), + synChecker := checker.TCP( + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{ + MSS: mss, + TS: true, + WS: int(c.WindowScale), + SACKPermitted: c.SACKEnabled(), + }), ) + checker.IPv4(c.t, b, synChecker) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) } @@ -948,6 +959,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Build SYN-ACK. c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) iss := seqnum.Value(TestInitialSequenceNumber) + if delay > 0 { + // Sleep so that RTT is increased. + time.Sleep(delay) + } c.SendPacket(nil, &Headers{ SrcPort: tcpSeg.DestinationPort(), DstPort: tcpSeg.SourcePort(), @@ -959,7 +974,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * }) // Read ACK. - ackPacket := c.GetPacket() + var ackPacket []byte + // Ignore retransimitted SYN packets. + for { + packet := c.GetPacket() + if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 { + checker.IPv4(c.t, packet, synChecker) + } else { + ackPacket = packet + break + } + } // Verify TCP header fields. tcpCheckers := []checker.TransportChecker{ @@ -1016,13 +1041,19 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * } } -// AcceptWithOptions initializes a listening endpoint and connects to it with the -// provided options enabled. It also verifies that the SYN-ACK has the expected -// values for the provided options. +// AcceptWithOptionsNoDelay delegates call to AcceptWithOptions without delay. +func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */) +} + +// AcceptWithOptions initializes a listening endpoint and connects to it with +// the provided options enabled. It delays before the final ACK of the 3WHS is +// sent. It also verifies that the SYN-ACK has the expected values for the +// provided options. // // The function returns a RawEndpoint representing the other end of the accepted // endpoint. -func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { +func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint { // Create EP and start listening. wq := &waiter.Queue{} ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) @@ -1045,7 +1076,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } - rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) + rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) @@ -1077,13 +1108,14 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption // PassiveConnectWithOptions. func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { synOptions.WS = -1 - c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) + c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */) } // PassiveConnectWithOptions initiates a new connection (with the specified TCP // options enabled) to the port on which the Context.ep is listening for new // connections. It also validates that the SYN-ACK has the expected values for -// the enabled options. +// the enabled options. The final ACK of the handshake is delayed by specified +// duration. // // NOTE: MSS is not a negotiated option and it can be asymmetric // in each direction. This function uses the maxPayload to set the MSS to be @@ -1093,7 +1125,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP // wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the // value of the window scaling option to be sent in the SYN. If synOptions.WS > // 0 then we send the WindowScale option. -func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { +func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint { c.t.Helper() opts := make([]byte, header.TCPOptionsMaximumSize) offset := 0 @@ -1180,7 +1212,10 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions ackHeaders.TCPOpts = opts[:] } - // Send ACK. + // Send ACK, delay if needed. + if delay > 0 { + time.Sleep(delay) + } c.SendPacket(nil, ackHeaders) c.RcvdWindowScale = uint8(rcvdSynOptions.WS) diff --git a/pkg/tcpip/transport/transport.go b/pkg/tcpip/transport/transport.go new file mode 100644 index 000000000..4c2ae87f4 --- /dev/null +++ b/pkg/tcpip/transport/transport.go @@ -0,0 +1,16 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package transport supports transport protocols. +package transport diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index cdc344ab7..d2c0963b0 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -35,6 +35,8 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/waiter", ], @@ -61,5 +63,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 82a3f2287..39b1e08c0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,8 +15,8 @@ package udp import ( + "fmt" "io" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sync" @@ -25,12 +25,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) // +stateify savable type udpPacket struct { udpPacketEntry + netProto tcpip.NetworkProtocolNumber senderAddress tcpip.FullAddress destinationAddress tcpip.FullAddress packetInfo tcpip.IPPacketInfo @@ -40,36 +43,6 @@ type udpPacket struct { tos uint8 } -// EndpointState represents the state of a UDP endpoint. -type EndpointState tcpip.EndpointState - -// Endpoint states. Note that are represented in a netstack-specific manner and -// may not be meaningful externally. Specifically, they need to be translated to -// Linux's representation for these states if presented to userspace. -const ( - _ EndpointState = iota - StateInitial - StateBound - StateConnected - StateClosed -) - -// String implements fmt.Stringer. -func (s EndpointState) String() string { - switch s { - case StateInitial: - return "INITIAL" - case StateBound: - return "BOUND" - case StateConnected: - return "CONNECTING" - case StateClosed: - return "CLOSED" - default: - return "UNKNOWN" - } -} - // endpoint represents a UDP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -79,7 +52,6 @@ func (s EndpointState) String() string { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and do not @@ -87,6 +59,9 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -96,37 +71,19 @@ type endpoint struct { rcvBufSize int rcvClosed bool - // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - // state must be read/set using the EndpointState()/setEndpointState() - // methods. - state uint32 - route *stack.Route `state:"manual"` - dstPort uint16 - ttl uint8 - multicastTTL uint8 - multicastAddr tcpip.Address - multicastNICID tcpip.NICID - portFlags ports.Flags - lastErrorMu sync.Mutex `state:"nosave"` lastError tcpip.Error + // The following fields are protected by the mu mutex. + mu sync.RWMutex `state:"nosave"` + portFlags ports.Flags + // Values used to reserve a port or register a transport endpoint. // (which ever happens first). boundBindToDevice tcpip.NICID boundPortFlags ports.Flags - // sendTOS represents IPv4 TOS or IPv6 TrafficClass, - // applied while sending packets. Defaults to 0 as on Linux. - sendTOS uint8 - - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - - // multicastMemberships that need to be remvoed when the endpoint is - // closed. Protected by the mu mutex. - multicastMemberships map[multicastMembership]struct{} + readShutdown bool // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -136,55 +93,25 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool -} -// +stateify savable -type multicastMembership struct { - nicID tcpip.NICID - multicastAddr tcpip.Address + localPort uint16 + remotePort uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: header.UDPProtocolNumber, - }, + stack: s, waiterQueue: waiterQueue, - // RFC 1075 section 5.4 recommends a TTL of 1 for membership - // requests. - // - // RFC 5135 4.2.1 appears to assume that IGMP messages have a - // TTL of 1. - // - // RFC 5135 Appendix A defines TTL=1: A multicast source that - // wants its traffic to not traverse a router (e.g., leave a - // home network) may find it useful to send traffic with IP - // TTL=1. - // - // Linux defaults to TTL=1. - multicastTTL: 1, - multicastMemberships: make(map[multicastMembership]struct{}), - state: uint32(StateInitial), - uniqueID: s.UniqueID(), + uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetSendBufferSize(32*1024, false /* notify */) e.ops.SetReceiveBufferSize(32*1024, false /* notify */) + e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -200,20 +127,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue return e } -// setEndpointState updates the state of the endpoint to state atomically. This -// method is unexported as the only place we should update the state is in this -// package but we allow the state to be read freely without holding e.mu. -// -// Precondition: e.mu must be held to call this method. -func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32(&e.state, uint32(state)) -} - -// EndpointState() returns the current state of the endpoint. -func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32(&e.state)) -} - // UniqueID implements stack.TransportEndpoint. func (e *endpoint) UniqueID() uint64 { return e.uniqueID @@ -244,16 +157,22 @@ func (e *endpoint) Abort() { // associated with it. func (e *endpoint) Close() { e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.EndpointState() { - case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + e.mu.Unlock() + return + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + id := e.net.Info().ID + id.LocalPort = e.localPort + id.RemotePort = e.remotePort + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice) portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: id.LocalAddress, + Port: id.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: tcpip.FullAddress{}, @@ -261,13 +180,10 @@ func (e *endpoint) Close() { e.stack.ReleasePort(portRes) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } - for mem := range e.multicastMemberships { - e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) - } - e.multicastMemberships = make(map[multicastMembership]struct{}) - // Close the receive list and drain it. e.rcvMu.Lock() e.rcvClosed = true @@ -278,14 +194,9 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - if e.route != nil { - e.route.Release() - e.route = nil - } - - // Update the state. - e.setEndpointState(StateClosed) - + e.net.Shutdown() + e.net.Close() + e.readShutdown = true e.mu.Unlock() e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) @@ -322,21 +233,38 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.receivedAt.UnixNano(), - } - if e.ops.GetReceiveTOS() { - cm.HasTOS = true - cm.TOS = p.tos - } - if e.ops.GetReceiveTClass() { - cm.HasTClass = true - // Although TClass is an 8-bit value it's read in the CMsg as a uint32. - cm.TClass = uint32(p.tos) + Timestamp: p.receivedAt, } - if e.ops.GetReceivePacketInfo() { - cm.HasIPPacketInfo = true - cm.PacketInfo = p.packetInfo + + switch p.netProto { + case header.IPv4ProtocolNumber: + if e.ops.GetReceiveTOS() { + cm.HasTOS = true + cm.TOS = p.tos + } + + if e.ops.GetReceivePacketInfo() { + cm.HasIPPacketInfo = true + cm.PacketInfo = p.packetInfo + } + case header.IPv6ProtocolNumber: + if e.ops.GetReceiveTClass() { + cm.HasTClass = true + // Although TClass is an 8-bit value it's read in the CMsg as a uint32. + cm.TClass = uint32(p.tos) + } + + if e.ops.GetIPv6ReceivePacketInfo() { + cm.HasIPv6PacketInfo = true + cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{ + NIC: p.packetInfo.NIC, + Addr: p.packetInfo.DestinationAddr, + } + } + default: + panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto)) } + if e.ops.GetReceiveOriginalDstAddress() { cm.HasOriginalDstAddress = true cm.OriginalDstAddress = p.destinationAddress @@ -359,19 +287,19 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult return res, nil } -// prepareForWrite prepares the endpoint for sending data. In particular, it -// binds it if it's still in the initial state. To do so, it must first +// prepareForWriteInner prepares the endpoint for sending data. In particular, +// it binds it if it's still in the initial state. To do so, it must first // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. // +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.EndpointState() { - case StateInitial: - case StateConnected: +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - case StateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -386,7 +314,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.EndpointState() != StateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -398,33 +326,6 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip return true, nil } -// connectRoute establishes a route to the specified interface or the -// configured multicast interface if no interface is specified and the -// specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { - localAddr := e.ID.LocalAddress - if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { - // A packet can only originate from a unicast address (i.e., an interface). - localAddr = "" - } - - if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicID == 0 { - nicID = e.multicastNICID - } - if localAddr == "" && nicID == 0 { - localAddr = e.multicastAddr - } - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) - if err != nil { - return nil, 0, err - } - return r, nicID, nil -} - // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { @@ -448,18 +349,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { +func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return udpPacketInfo{}, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(opts.To) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { return udpPacketInfo{}, err } @@ -469,49 +365,28 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions } } - route := e.route - dstPort := e.dstPort + dst, connected := e.net.GetRemoteAddress() + dst.Port = e.remotePort if opts.To != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := opts.To.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return udpPacketInfo{}, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - if opts.To.Port == 0 { // Port 0 is an invalid port to send to. return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{} } - dst, netProto, err := e.checkV4MappedLocked(*opts.To) - if err != nil { - return udpPacketInfo{}, err - } - - r, _, err := e.connectRoute(nicID, dst, netProto) - if err != nil { - return udpPacketInfo{}, err - } - defer r.Release() - - route = r - dstPort = dst.Port + dst = *opts.To + } else if !connected { + return udpPacketInfo{}, &tcpip.ErrDestinationRequired{} } - if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{} + ctx, err := e.net.AcquireContextForWrite(opts) + if err != nil { + return udpPacketInfo{}, err } + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { + ctx.Release() return udpPacketInfo{}, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { @@ -520,50 +395,25 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions if so.GetRecvError() { so.QueueLocalErr( &tcpip.ErrMessageTooLong{}, - route.NetProto(), + e.net.NetProto(), header.UDPMaximumPacketSize, - tcpip.FullAddress{ - NIC: route.NICID(), - Addr: route.RemoteAddress(), - Port: dstPort, - }, + dst, v, ) } + ctx.Release() return udpPacketInfo{}, &tcpip.ErrMessageTooLong{} } - ttl := e.ttl - useDefaultTTL := ttl == 0 - if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { - ttl = e.multicastTTL - // Multicast allows a 0 TTL. - useDefaultTTL = false - } - return udpPacketInfo{ - route: route, - data: buffer.View(v), - localPort: e.ID.LocalPort, - remotePort: dstPort, - ttl: ttl, - useDefaultTTL: useDefaultTTL, - tos: e.sendTOS, - owner: e.owner, - noChecksum: e.SocketOptions().GetNoChecksum(), + ctx: ctx, + data: v, + localPort: e.localPort, + remotePort: dst.Port, }, nil } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if err := e.LastError(); err != nil { - return 0, err - } - - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - // Do not hold lock when sending as loopback is synchronous and if the UDP // datagram ends up generating an ICMP response then it can result in a // deadlock where the ICMP response handling ends up acquiring this endpoint's @@ -574,15 +424,53 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read // locking is prohibited. - u, err := e.buildUDPPacketInfo(p, opts) - if err != nil { + + if err := e.LastError(); err != nil { return 0, err } - n, err := u.send() + + udpInfo, err := e.prepareForWrite(p, opts) if err != nil { return 0, err } - return int64(n), nil + defer udpInfo.ctx.Release() + + pktInfo := udpInfo.ctx.PacketInfo() + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength), + Data: udpInfo.data.ToVectorisedView(), + }) + + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = ProtocolNumber + + length := uint16(pkt.Size()) + udp.Encode(&header.UDPFields{ + SrcPort: udpInfo.localPort, + DstPort: udpInfo.remotePort, + Length: length, + }) + + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if pktInfo.RequiresTXTransportChecksum && + (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) { + udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine( + header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length), + pkt.Data().AsRange().Checksum(), + ))) + } + if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + e.stack.Stats().UDP.PacketSendErrors.Increment() + return 0, err + } + + // Track count of packets sent. + e.stack.Stats().UDP.PacketsSent.Increment() + return int64(len(udpInfo.data)), nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler. @@ -601,36 +489,7 @@ func (e *endpoint) OnReusePortSet(v bool) { // SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.MTUDiscoverOption: - // Return not supported if the value is not disabling path - // MTU discovery. - if v != tcpip.PMTUDiscoveryDont { - return &tcpip.ErrNotSupported{} - } - - case tcpip.MulticastTTLOption: - e.mu.Lock() - e.multicastTTL = uint8(v) - e.mu.Unlock() - - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - case tcpip.IPv4TOSOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - } - - return nil + return e.net.SetSockOptInt(opt, v) } var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) @@ -642,145 +501,12 @@ func (e *endpoint) HasNIC(id int32) bool { // SetSockOpt implements tcpip.Endpoint. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { - switch v := opt.(type) { - case *tcpip.MulticastInterfaceOption: - e.mu.Lock() - defer e.mu.Unlock() - - fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedLocked(fa) - if err != nil { - return err - } - nic := v.NIC - addr := fa.Addr - - if nic == 0 && addr == "" { - e.multicastAddr = "" - e.multicastNICID = 0 - break - } - - if nic != 0 { - if !e.stack.CheckNIC(nic) { - return &tcpip.ErrBadLocalAddress{} - } - } else { - nic = e.stack.CheckLocalAddress(0, netProto, addr) - if nic == 0 { - return &tcpip.ErrBadLocalAddress{} - } - } - - if e.BindNICID != 0 && e.BindNICID != nic { - return &tcpip.ErrInvalidEndpointState{} - } - - e.multicastNICID = nic - e.multicastAddr = addr - - case *tcpip.AddMembershipOption: - if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return &tcpip.ErrInvalidOptionValue{} - } - - nicID := v.NIC - - if v.InterfaceAddr.Unspecified() { - if nicID == 0 { - if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil { - nicID = r.NICID() - r.Release() - } - } - } else { - nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) - } - if nicID == 0 { - return &tcpip.ErrUnknownDevice{} - } - - memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - - e.mu.Lock() - defer e.mu.Unlock() - - if _, ok := e.multicastMemberships[memToInsert]; ok { - return &tcpip.ErrPortInUse{} - } - - if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { - return err - } - - e.multicastMemberships[memToInsert] = struct{}{} - - case *tcpip.RemoveMembershipOption: - if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return &tcpip.ErrInvalidOptionValue{} - } - - nicID := v.NIC - if v.InterfaceAddr.Unspecified() { - if nicID == 0 { - if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil { - nicID = r.NICID() - r.Release() - } - } - } else { - nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) - } - if nicID == 0 { - return &tcpip.ErrUnknownDevice{} - } - - memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - - e.mu.Lock() - defer e.mu.Unlock() - - if _, ok := e.multicastMemberships[memToRemove]; !ok { - return &tcpip.ErrBadLocalAddress{} - } - - if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { - return err - } - - delete(e.multicastMemberships, memToRemove) - - case *tcpip.SocketDetachFilterOption: - return nil - } - return nil + return e.net.SetSockOpt(opt) } // GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { - case tcpip.IPv4TOSOption: - e.mu.RLock() - v := int(e.sendTOS) - e.mu.RUnlock() - return v, nil - - case tcpip.IPv6TrafficClassOption: - e.mu.RLock() - v := int(e.sendTOS) - e.mu.RUnlock() - return v, nil - - case tcpip.MTUDiscoverOption: - // The only supported setting is path MTU discovery disabled. - return tcpip.PMTUDiscoveryDont, nil - - case tcpip.MulticastTTLOption: - e.mu.Lock() - v := int(e.multicastTTL) - e.mu.Unlock() - return v, nil - case tcpip.ReceiveQueueSizeOption: v := 0 e.rcvMu.Lock() @@ -791,108 +517,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.mu.Lock() - v := int(e.ttl) - e.mu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } // GetSockOpt implements tcpip.Endpoint. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { - switch o := opt.(type) { - case *tcpip.MulticastInterfaceOption: - e.mu.Lock() - *o = tcpip.MulticastInterfaceOption{ - NIC: e.multicastNICID, - InterfaceAddr: e.multicastAddr, - } - e.mu.Unlock() - - default: - return &tcpip.ErrUnknownProtocolOption{} - } - return nil + return e.net.GetSockOpt(opt) } -// udpPacketInfo contains all information required to send a UDP packet. -// -// This should be used as a value-only type, which exists in order to simplify -// return value syntax. It should not be exported or extended. +// udpPacketInfo holds information needed to send a UDP packet. type udpPacketInfo struct { - route *stack.Route - data buffer.View - localPort uint16 - remotePort uint16 - ttl uint8 - useDefaultTTL bool - tos uint8 - owner tcpip.PacketOwner - noChecksum bool -} - -// send sends the given packet. -func (u *udpPacketInfo) send() (int, tcpip.Error) { - vv := u.data.ToVectorisedView() - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()), - Data: vv, - }) - pkt.Owner = u.owner - - // Initialize the UDP header. - udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - pkt.TransportProtocolNumber = ProtocolNumber - - length := uint16(pkt.Size()) - udp.Encode(&header.UDPFields{ - SrcPort: u.localPort, - DstPort: u.remotePort, - Length: length, - }) - - // Set the checksum field unless TX checksum offload is enabled. - // On IPv4, UDP checksum is optional, and a zero value indicates the - // transmitter skipped the checksum generation (RFC768). - // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if u.route.RequiresTXTransportChecksum() && - (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) { - xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length) - for _, v := range vv.Views() { - xsum = header.Checksum(v, xsum) - } - udp.SetChecksum(^udp.CalculateChecksum(xsum)) - } - - if u.useDefaultTTL { - u.ttl = u.route.DefaultTTL() - } - if err := u.route.WritePacket(stack.NetworkHeaderParams{ - Protocol: ProtocolNumber, - TTL: u.ttl, - TOS: u.tos, - }, pkt); err != nil { - u.route.Stats().UDP.PacketSendErrors.Increment() - return 0, err - } - - // Track count of packets sent. - u.route.Stats().UDP.PacketsSent.Increment() - return len(u.data), nil -} - -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil + ctx network.WriteContext + data buffer.View + localPort uint16 + remotePort uint16 } // Disconnect implements tcpip.Endpoint. @@ -900,7 +540,7 @@ func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.EndpointState() != StateConnected { + if e.net.State() != transport.DatagramEndpointStateConnected { return nil } var ( @@ -913,26 +553,28 @@ func (e *endpoint) Disconnect() tcpip.Error { boundPortFlags := e.boundPortFlags // Exclude ephemerally bound endpoints. - if e.BindNICID != 0 || e.ID.LocalAddress == "" { + info := e.net.Info() + info.ID.LocalPort = e.localPort + info.ID.RemotePort = e.remotePort + if e.net.WasBound() { var err tcpip.Error id = stack.TransportEndpointID{ - LocalPort: e.ID.LocalPort, - LocalAddress: e.ID.LocalAddress, + LocalPort: info.ID.LocalPort, + LocalAddress: info.ID.LocalAddress, } id, btd, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { return err } - e.setEndpointState(StateBound) boundPortFlags = e.boundPortFlags } else { - if e.ID.LocalPort != 0 { + if info.ID.LocalPort != 0 { // Release the ephemeral port. portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: info.ID.LocalAddress, + Port: info.ID.LocalPort, Flags: boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: tcpip.FullAddress{}, @@ -940,15 +582,14 @@ func (e *endpoint) Disconnect() tcpip.Error { e.stack.ReleasePort(portRes) e.boundPortFlags = ports.Flags{} } - e.setEndpointState(StateInitial) } - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) - e.ID = id + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice) e.boundBindToDevice = btd - e.route.Release() - e.route = nil - e.dstPort = 0 + e.localPort = id.LocalPort + e.remotePort = id.RemotePort + + e.net.Disconnect() return nil } @@ -958,88 +599,48 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - var localPort uint16 - switch e.EndpointState() { - case StateInitial: - case StateBound, StateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } - - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.localPort + nextID.RemotePort = addr.Port + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv4ProtocolNumber, + header.IPv6ProtocolNumber, + } } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } + oldPortFlags := e.boundPortFlags - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - r, nicID, err := e.connectRoute(nicID, addr, netProto) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: e.ID.LocalAddress, - LocalPort: localPort, - RemotePort: addr.Port, - RemoteAddress: r.RemoteAddress(), - } - - if e.EndpointState() == StateInitial { - id.LocalAddress = r.LocalAddress() - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv4ProtocolNumber, - header.IPv6ProtocolNumber, + nextID, btd, err := e.registerWithStack(netProtos, nextID) + if err != nil { + return err } - } - oldPortFlags := e.boundPortFlags + // Remove the old registration. + if e.localPort != 0 { + previousID.LocalPort = e.localPort + previousID.RemotePort = e.remotePort + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice) + } - id, btd, err := e.registerWithStack(netProtos, id) + e.localPort = nextID.LocalPort + e.remotePort = nextID.RemotePort + e.boundBindToDevice = btd + e.effectiveNetProtos = netProtos + return nil + }) if err != nil { - r.Release() return err } - // Remove the old registration. - if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) - } - - e.ID = id - e.boundBindToDevice = btd - if e.route != nil { - // If the endpoint was already connected then make sure we release the - // previous route. - e.route.Release() - } - e.route = r - e.dstPort = addr.Port - e.RegisterNICID = nicID - e.effectiveNetProtos = netProtos - - e.setEndpointState(StateConnected) - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() - return nil } @@ -1054,15 +655,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - // A socket in the bound state can still receive multicast messages, - // so we need to notify waiters on shutdown. - if state := e.EndpointState(); state != StateBound && state != StateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } - e.shutdownFlags |= flags + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } + } if flags&tcpip.ShutdownRead != 0 { + e.readShutdown = true + e.rcvMu.Lock() wasClosed := e.rcvClosed e.rcvClosed = true @@ -1088,7 +697,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - if e.ID.LocalPort == 0 { + if e.localPort == 0 { portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, @@ -1126,56 +735,43 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.EndpointState() != StateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{boundNetProto} + if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } } - } - nicID := addr.NIC - if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { - // A local unicast address was specified, verify that it's valid. - nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicID == 0 { - return &tcpip.ErrBadLocalAddress{} + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: boundAddr, + } + id, btd, err := e.registerWithStack(netProtos, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, btd, err := e.registerWithStack(netProtos, id) + e.localPort = id.LocalPort + e.boundBindToDevice = btd + e.effectiveNetProtos = netProtos + return nil + }) if err != nil { return err } - e.ID = id - e.boundBindToDevice = btd - e.RegisterNICID = nicID - e.effectiveNetProtos = netProtos - - // Mark endpoint as bound. - e.setEndpointState(StateBound) - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() - return nil } @@ -1190,9 +786,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { return err } - // Save the effective NICID generated by bindLocked. - e.BindNICID = e.RegisterNICID - return nil } @@ -1201,16 +794,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - addr := e.ID.LocalAddress - if e.EndpointState() == StateConnected { - addr = e.route.LocalAddress() - } - - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: addr, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.localPort + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -1218,15 +804,13 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.EndpointState() != StateConnected || e.dstPort == 0 { + addr, connected := e.net.GetRemoteAddress() + if !connected || e.remotePort == 0 { return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + addr.Port = e.remotePort + return addr, nil } // Readiness returns the current readiness of the endpoint. For example, if @@ -1321,6 +905,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB // Push new packet into receive list and increment the buffer size. packet := &udpPacket{ + netProto: pkt.NetworkProtocolNumber, senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, @@ -1376,19 +961,20 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p payload = udp.Payload() } + id := e.net.Info().ID e.SocketOptions().QueueErr(&tcpip.SockError{ Err: err, Cause: transErr, Payload: payload, Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: id.RemoteAddress, + Port: e.remotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: id.LocalAddress, + Port: e.localPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -1403,7 +989,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // TODO(gvisor.dev/issues/5270): Handle all transport errors. switch transErr.Kind() { case stack.DestinationPortUnreachableTransportError: - if e.EndpointState() == StateConnected { + if e.net.State() == transport.DatagramEndpointStateConnected { e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt) } } @@ -1411,16 +997,17 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // State implements tcpip.Endpoint. func (e *endpoint) State() uint32 { - return uint32(e.EndpointState()) + return uint32(e.net.State()) } // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() - return &ret + defer e.mu.RUnlock() + info := e.net.Info() + info.ID.LocalPort = e.localPort + info.ID.RemotePort = e.remotePort + return &info } // Stats returns a pointer to the endpoint stats. @@ -1431,13 +1018,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements tcpip.Endpoint. func (*endpoint) Wait() {} -func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) -} - // SetOwner implements tcpip.Endpoint. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // SocketOptions implements tcpip.Endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 1f638c3f6..2ff8b0482 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,12 +15,13 @@ package udp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -35,17 +36,11 @@ func (p *udpPacket) loadReceivedAt(nsec int64) { // saveData saves udpPacket.data field. func (p *udpPacket) saveData() buffer.VectorisedView { - // We cannot save p.data directly as p.data.views may alias to p.views, - // which is not allowed by state framework (in-struct pointer). return p.data.Clone(nil) } // loadData loads udpPacket.data field. func (p *udpPacket) loadData(data buffer.VectorisedView) { - // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization - // here because data.views is not guaranteed to be loaded by now. Plus, - // data.views will be allocated anyway so there really is little point - // of utilizing p.views for data.views. p.data = data } @@ -66,50 +61,28 @@ func (e *endpoint) Resume(s *stack.Stack) { e.mu.Lock() defer e.mu.Unlock() + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - for m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } - - state := e.EndpointState() - if state != StateBound && state != StateConnected { - return - } - - netProto := e.effectiveNetProtos[0] - // Connect() and bindLocked() both assert - // - // netProto == header.IPv6ProtocolNumber - // - // before creating a multi-entry effectiveNetProtos. - if len(e.effectiveNetProtos) > 1 { - netProto = header.IPv6ProtocolNumber - } - - var err tcpip.Error - if state == StateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + var err tcpip.Error + id := e.net.Info().ID + id.LocalPort = e.localPort + id.RemotePort = e.remotePort + id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { panic(err) } - } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound - // A local unicast address is specified, verify that it's valid. - if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - // Our saved state had a port, but we don't actually have a - // reservation. We need to remove the port from our state, but still - // pass it to the reservation machinery. - id := e.ID - e.ID.LocalPort = 0 - e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) - if err != nil { - panic(err) + e.localPort = id.LocalPort + e.remotePort = id.RemotePort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 7c357cb09..7238fc019 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -70,28 +70,29 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) + ep.mu.Lock() + defer ep.mu.Unlock() + netHdr := r.pkt.Network() - route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) - if err != nil { + if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil { + return nil, err + } + + if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil { return nil, err } - ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() - route.Release() return nil, err } - ep.ID = r.id - ep.route = route - ep.dstPort = r.id.RemotePort + ep.localPort = r.id.LocalPort + ep.remotePort = r.id.RemotePort ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} - ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = uint32(StateConnected) - ep.rcvMu.Lock() ep.rcvReady = true ep.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4008cacf2..b3199489c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -290,6 +291,7 @@ type testContext struct { t *testing.T linkEP *channel.Endpoint s *stack.Stack + nicID tcpip.NICID ep tcpip.Endpoint wq waiter.Queue @@ -301,6 +303,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { } func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext { + const nicID = 1 + t.Helper() options := stack.Options{ @@ -310,38 +314,50 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo Clock: &faketime.NullClock{}, } s := stack.New(options) + // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus + // never allows ICMP messages. + s.SetICMPLimit(rate.Inf) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) + if err := s.CreateNIC(nicID, wep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %s", err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) } - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %s", err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) } s.SetRouteTable([]tcpip.Route{ { Destination: header.IPv4EmptySubnet, - NIC: 1, + NIC: nicID, }, { Destination: header.IPv6EmptySubnet, - NIC: 1, + NIC: nicID, }, }) return &testContext{ t: t, s: s, + nicID: nicID, linkEP: ep, } } @@ -1353,64 +1369,70 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { func TestReadIPPacketInfo(t *testing.T) { tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedLocalAddr tcpip.Address - expectedDestAddr tcpip.Address + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + checker func(tcpip.NICID) checker.ControlMessagesChecker }{ { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedLocalAddr: stackAddr, - expectedDestAddr: stackAddr, + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + LocalAddr: stackAddr, + DestinationAddr: stackAddr, + }) + }, }, { name: "IPv4 multicast", proto: header.IPv4ProtocolNumber, flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastAddr, - expectedDestAddr: multicastAddr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + // TODO(gvisor.dev/issue/3556): Check for a unicast address. + LocalAddr: multicastAddr, + DestinationAddr: multicastAddr, + }) + }, }, { name: "IPv4 broadcast", proto: header.IPv4ProtocolNumber, flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: broadcastAddr, - expectedDestAddr: broadcastAddr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: id, + // TODO(gvisor.dev/issue/3556): Check for a unicast address. + LocalAddr: broadcastAddr, + DestinationAddr: broadcastAddr, + }) + }, }, { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedLocalAddr: stackV6Addr, - expectedDestAddr: stackV6Addr, + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{ + NIC: id, + Addr: stackV6Addr, + }) + }, }, { name: "IPv6 multicast", proto: header.IPv6ProtocolNumber, flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastV6Addr, - expectedDestAddr: multicastV6Addr, + checker: func(id tcpip.NICID) checker.ControlMessagesChecker { + return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{ + NIC: id, + Addr: multicastV6Addr, + }) + }, }, } @@ -1433,13 +1455,16 @@ func TestReadIPPacketInfo(t *testing.T) { } } - c.ep.SocketOptions().SetReceivePacketInfo(true) + switch f := test.flow.netProto(); f { + case header.IPv4ProtocolNumber: + c.ep.SocketOptions().SetReceivePacketInfo(true) + case header.IPv6ProtocolNumber: + c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true) + default: + t.Fatalf("unhandled protocol number = %d", f) + } - testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ - NIC: 1, - LocalAddr: test.expectedLocalAddr, - DestinationAddr: test.expectedDestAddr, - })) + testRead(c, test.flow, test.checker(c.nicID)) if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) @@ -1644,8 +1669,10 @@ func TestSetTTL(t *testing.T) { } } +var v4PacketFlows = [...]testFlow{unicastV4, multicastV4, broadcast, unicastV4in6, multicastV4in6, broadcastIn6} + func TestSetTOS(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + for _, flow := range v4PacketFlows { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1680,8 +1707,10 @@ func TestSetTOS(t *testing.T) { } } +var v6PacketFlows = [...]testFlow{unicastV6, unicastV6Only, multicastV6} + func TestSetTClass(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { + for _, flow := range v6PacketFlows { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1725,8 +1754,14 @@ func TestReceiveTosTClass(t *testing.T) { name string tests []testFlow }{ - {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, - {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + { + name: RcvTOSOpt, + tests: v4PacketFlows[:], + }, + { + name: RcvTClassOpt, + tests: v6PacketFlows[:], + }, } for _, testCase := range testCases { for _, flow := range testCase.tests { @@ -1737,6 +1772,14 @@ func TestReceiveTosTClass(t *testing.T) { c.createEndpointForFlow(flow) name := testCase.name + if flow.isMulticast() { + netProto := flow.netProto() + addr := flow.getMcastAddr() + if err := c.s.JoinGroup(netProto, c.nicID, addr); err != nil { + c.t.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, c.nicID, addr, err) + } + } + var optionGetter func() bool var optionSetter func(bool) switch name { @@ -2482,8 +2525,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index 234125c38..8902be2d3 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/eventfd", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go index 40fa72925..0dc0c37bd 100644 --- a/pkg/unet/unet.go +++ b/pkg/unet/unet.go @@ -23,6 +23,7 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/eventfd" "gvisor.dev/gvisor/pkg/sync" ) @@ -55,15 +56,6 @@ func socket(packet bool) (int, error) { return fd, nil } -// eventFD returns a new event FD with initial value 0. -func eventFD() (int, error) { - f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0) - if e != 0 { - return -1, e - } - return int(f), nil -} - // Socket is a connected unix domain socket. type Socket struct { // gate protects use of fd. @@ -78,7 +70,7 @@ type Socket struct { // efd is an event FD that is signaled when the socket is closing. // // efd is immutable and remains valid until Close/Release. - efd int + efd eventfd.Eventfd // race is an atomic variable used to avoid triggering the race // detector. See comment in SocketPair below. @@ -95,7 +87,7 @@ func NewSocket(fd int) (*Socket, error) { return nil, err } - efd, err := eventFD() + efd, err := eventfd.Create() if err != nil { return nil, err } @@ -110,16 +102,14 @@ func NewSocket(fd int) (*Socket, error) { // closing the event FD. func (s *Socket) finish() error { // Signal any blocked or future polls. - // - // N.B. eventfd writes must be 8 bytes. - if _, err := unix.Write(s.efd, []byte{1, 0, 0, 0, 0, 0, 0, 0}); err != nil { + if err := s.efd.Notify(); err != nil { return err } // Close the gate, blocking until all FD users leave. s.gate.Close() - return unix.Close(s.efd) + return s.efd.Close() } // Close closes the socket. diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go index f0bf93ddd..ea281fec3 100644 --- a/pkg/unet/unet_unsafe.go +++ b/pkg/unet/unet_unsafe.go @@ -43,7 +43,7 @@ func (s *Socket) wait(write bool) error { }, { // The eventfd, signaled when we are closing. - Fd: int32(s.efd), + Fd: int32(s.efd.FD()), Events: unix.POLLIN, }, } diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index cde1038ed..f46a00e42 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -429,7 +429,7 @@ type IOSequence struct { // return 0, nil // } // if f.availableBytes == 0 { -// return 0, syserror.ErrWouldBlock +// return 0, linuxerr.ErrWouldBlock // } // return ioseq.CopyOutFrom(..., reader) // |