diff options
Diffstat (limited to 'pkg')
304 files changed, 11281 insertions, 3605 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index ba233b93f..f45934466 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -2,9 +2,11 @@ # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix # when the host OS may not be Linux. +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "linux", @@ -44,6 +46,7 @@ go_library( "sem.go", "shm.go", "signal.go", + "signalfd.go", "socket.go", "splice.go", "tcp.go", diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 7d742871a..257f67222 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -271,7 +271,7 @@ type Statx struct { } // FileMode represents a mode_t. -type FileMode uint +type FileMode uint16 // Permissions returns just the permission bits. func (m FileMode) Permissions() FileMode { diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go new file mode 100644 index 000000000..85fad9956 --- /dev/null +++ b/pkg/abi/linux/signalfd.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +const ( + // SFD_NONBLOCK is a signalfd(2) flag. + SFD_NONBLOCK = 00004000 + + // SFD_CLOEXEC is a signalfd(2) flag. + SFD_CLOEXEC = 02000000 +) + +// SignalfdSiginfo is the siginfo encoding for signalfds. +type SignalfdSiginfo struct { + Signo uint32 + Errno int32 + Code int32 + PID uint32 + UID uint32 + FD int32 + TID uint32 + Band uint32 + Overrun uint32 + TrapNo uint32 + Status int32 + Int int32 + Ptr uint64 + UTime uint64 + STime uint64 + Addr uint64 + AddrLSB uint16 + _ [48]uint8 +} diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 39d253b98..6bc486b62 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 47ab65346..36beaade9 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -7,6 +8,7 @@ go_library( srcs = [ "atomic_bitops.go", "atomic_bitops_amd64.s", + "atomic_bitops_arm64.s", "atomic_bitops_common.go", ], importpath = "gvisor.dev/gvisor/pkg/atomicbitops", diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomic_bitops.go index 63aa2b7f1..fcc41a9ea 100644 --- a/pkg/atomicbitops/atomic_bitops.go +++ b/pkg/atomicbitops/atomic_bitops.go @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 // Package atomicbitops provides basic bitwise operations in an atomic way. // The implementation on amd64 leverages the LOCK prefix directly instead of -// relying on the generic cas primitives. +// relying on the generic cas primitives, and the arm64 leverages the LDAXR +// and STLXR pair primitives. // // WARNING: the bitwise ops provided in this package doesn't imply any memory // ordering. Using them to construct locks must employ proper memory barriers. diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomic_bitops_arm64.s new file mode 100644 index 000000000..97f8808c1 --- /dev/null +++ b/pkg/atomicbitops/atomic_bitops_arm64.s @@ -0,0 +1,139 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +#include "textflag.h" + +TEXT ·AndUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + ANDW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·OrUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + ORRW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·XorUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + EORW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·CompareAndSwapUint32(SB),$0-20 + MOVD addr+0(FP), R0 + MOVW old+8(FP), R1 + MOVW new+12(FP), R2 + +again: + LDAXRW (R0), R3 + CMPW R1, R3 + BNE done + STLXRW R2, (R0), R4 + CBNZ R4, again +done: + MOVW R3, prev+16(FP) + RET + +TEXT ·AndUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + AND R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·OrUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + ORR R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·XorUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + EOR R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·CompareAndSwapUint64(SB),$0-32 + MOVD addr+0(FP), R0 + MOVD old+8(FP), R1 + MOVD new+16(FP), R2 + +again: + LDAXR (R0), R3 + CMP R1, R3 + BNE done + STLXR R2, (R0), R4 + CBNZ R4, again +done: + MOVD R3, prev+24(FP) + RET + +TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9 + MOVD addr+0(FP), R0 + +again: + LDAXRW (R0), R1 + CBZ R1, fail + ADDW $1, R1 + STLXRW R1, (R0), R2 + CBNZ R2, again + MOVW $1, R2 + MOVB R2, ret+8(FP) + RET +fail: + MOVB ZR, ret+8(FP) + RET + +TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9 + MOVD addr+0(FP), R0 + +again: + LDAXRW (R0), R1 + SUBSW $1, R1, R1 + BEQ fail + STLXRW R1, (R0), R2 + CBNZ R2, again + MOVW $1, R2 + MOVB R2, ret+8(FP) + RET +fail: + MOVB ZR, ret+8(FP) + RET diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomic_bitops_common.go index b2a943dcb..85163ad62 100644 --- a/pkg/atomicbitops/atomic_bitops_common.go +++ b/pkg/atomicbitops/atomic_bitops_common.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !amd64 +// +build !amd64,!arm64 package atomicbitops diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD index 09d6c2c1f..543fb54bf 100644 --- a/pkg/binary/BUILD +++ b/pkg/binary/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 0c2dde4f8..1b5dac99a 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -10,8 +11,9 @@ go_library( "bits.go", "bits32.go", "bits64.go", - "uint64_arch_amd64.go", + "uint64_arch.go", "uint64_arch_amd64_asm.s", + "uint64_arch_arm64_asm.s", "uint64_arch_generic.go", ], importpath = "gvisor.dev/gvisor/pkg/bits", diff --git a/pkg/bits/uint64_arch_amd64.go b/pkg/bits/uint64_arch.go index faccaa61a..9f23eff77 100644 --- a/pkg/bits/uint64_arch_amd64.go +++ b/pkg/bits/uint64_arch.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 package bits diff --git a/pkg/bits/uint64_arch_arm64_asm.s b/pkg/bits/uint64_arch_arm64_asm.s new file mode 100644 index 000000000..814ba562d --- /dev/null +++ b/pkg/bits/uint64_arch_arm64_asm.s @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +TEXT ·TrailingZeros64(SB),$0-16 + MOVD x+0(FP), R0 + RBIT R0, R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD R0, ret+8(FP) + RET + +TEXT ·MostSignificantOne64(SB),$0-16 + MOVD x+0(FP), R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD $63, R1 + SUBS R0, R1, R0 // ret = 63 - CLZ + BPL end + MOVD $64, R0 // x == 0 +end: + MOVD R0, ret+8(FP) + RET diff --git a/pkg/bits/uint64_arch_generic.go b/pkg/bits/uint64_arch_generic.go index 7dd2d1480..9dd2098d1 100644 --- a/pkg/bits/uint64_arch_generic.go +++ b/pkg/bits/uint64_arch_generic.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !amd64 +// +build !amd64,!arm64 package bits diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index b692aa3b1..8d31e068c 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "bpf", diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index cdec96df1..a0b21d4bd 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index 830e19e07..32422f9e2 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "cpuid", diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index 3fabaf445..5d61dc2ff 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -418,6 +418,73 @@ var x86FeatureParseOnlyStrings = map[Feature]string{ X86FeaturePREFETCHWT1: "prefetchwt1", } +// intelCacheDescriptors describe the caches and TLBs on the system. They are +// returned in the registers for eax=2. Intel only. +type intelCacheDescriptor uint8 + +// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol. +// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors". +const ( + intelNullDescriptor intelCacheDescriptor = 0 + intelNoTLBDescriptor intelCacheDescriptor = 0xfe + intelNoCacheDescriptor intelCacheDescriptor = 0xff + + // Most descriptors omitted for brevity as they are currently unused. +) + +// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4. +type CacheType uint8 + +const ( + // cacheNull indicates that there are no more entries. + cacheNull CacheType = iota + + // CacheData is a data cache. + CacheData + + // CacheInstruction is an instruction cache. + CacheInstruction + + // CacheUnified is a unified instruction and data cache. + CacheUnified +) + +// Cache describes the parameters of a single cache on the system. +// +// +stateify savable +type Cache struct { + // Level is the hierarchical level of this cache (L1, L2, etc). + Level uint32 + + // Type is the type of cache. + Type CacheType + + // FullyAssociative indicates that entries may be placed in any block. + FullyAssociative bool + + // Partitions is the number of physical partitions in the cache. + Partitions uint32 + + // Ways is the number of ways of associativity in the cache. + Ways uint32 + + // Sets is the number of sets in the cache. + Sets uint32 + + // InvalidateHierarchical indicates that WBINVD/INVD from threads + // sharing this cache acts upon lower level caches for threads sharing + // this cache. + InvalidateHierarchical bool + + // Inclusive indicates that this cache is inclusive of lower cache + // levels. + Inclusive bool + + // DirectMapped indicates that this cache is directly mapped from + // address, rather than using a hash function. + DirectMapped bool +} + // Just a way to wrap cpuid function numbers. type cpuidFunction uint32 @@ -494,7 +561,7 @@ func (f Feature) flagString(cpuinfoOnly bool) string { return "" } -// FeatureSet is a set of Features for a cpu. +// FeatureSet is a set of Features for a CPU. // // +stateify savable type FeatureSet struct { @@ -521,6 +588,15 @@ type FeatureSet struct { // SteppingID is part of the processor signature. SteppingID uint8 + + // Caches describes the caches on the CPU. + Caches []Cache + + // CacheLine is the size of a cache line in bytes. + // + // All caches use the same line size. This is not enforced in the CPUID + // encoding, but is true on all known x86 processors. + CacheLine uint32 } // FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is @@ -557,22 +633,27 @@ func (fs FeatureSet) CPUInfo(cpu uint) string { fmt.Fprintln(&b, "wp\t\t: yes") fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true)) fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway. - fmt.Fprintf(&b, "clflush size\t: %d\n", 64) - fmt.Fprintf(&b, "cache_alignment\t: %d\n", 64) + fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine) + fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine) fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48) fmt.Fprintln(&b, "power management:") // This is always here, but can be blank. fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline. return b.String() } +const ( + amdVendorID = "AuthenticAMD" + intelVendorID = "GenuineIntel" +) + // AMD returns true if fs describes an AMD CPU. func (fs *FeatureSet) AMD() bool { - return fs.VendorID == "AuthenticAMD" + return fs.VendorID == amdVendorID } // Intel returns true if fs describes an Intel CPU. func (fs *FeatureSet) Intel() bool { - return fs.VendorID == "GenuineIntel" + return fs.VendorID == intelVendorID } // ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a @@ -589,9 +670,18 @@ func (e ErrIncompatible) Error() string { // CheckHostCompatible returns nil if fs is a subset of the host feature set. func (fs *FeatureSet) CheckHostCompatible() error { hfs := HostFeatureSet() + if diff := fs.Subtract(hfs); diff != nil { return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)} } + + // The size of a cache line must match, as it is critical to correctly + // utilizing CLFLUSH. Other cache properties are allowed to change, as + // they are not important to correctness. + if fs.CacheLine != hfs.CacheLine { + return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)} + } + return nil } @@ -732,14 +822,6 @@ func (fs *FeatureSet) HasFeature(feature Feature) bool { return fs.Set[feature] } -// IsSubset returns true if the FeatureSet is a subset of the FeatureSet passed in. -// This is useful if you want to see if a FeatureSet is compatible with another -// FeatureSet, since you can only run with a given FeatureSet if it's a subset of -// the host's. -func (fs *FeatureSet) IsSubset(other *FeatureSet) bool { - return fs.Subtract(other) == nil -} - // Subtract returns the features present in fs that are not present in other. // If all features in fs are present in other, Subtract returns nil. func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) { @@ -755,17 +837,6 @@ func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) { return } -// TakeFeatureIntersection will set the features in `fs` to the intersection of -// the features in `fs` and `other` (effectively clearing any feature bits on -// `fs` that are not also set in `other`). -func (fs *FeatureSet) TakeFeatureIntersection(other *FeatureSet) { - for f := range fs.Set { - if !other.Set[f] { - delete(fs.Set, f) - } - } -} - // EmulateID emulates a cpuid instruction based on the feature set. func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) { switch cpuidFunction(origAx) { @@ -773,9 +844,8 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) { ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support. bx, dx, cx = fs.vendorIDRegs() case featureInfo: - // clflush line size (ebx bits[15:8]) hardcoded as 8. This - // means cache lines of size 64 bytes. - bx = 8 << 8 + // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported. + bx = (fs.CacheLine / 8) << 8 cx = fs.blockMask(block(0)) dx = fs.blockMask(block(1)) ax = fs.signature() @@ -789,10 +859,46 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) { // will always return 01H. Software should ignore this value // and not interpret it as an informational descriptor." - SDM // - // We do not support exposing cache information, but we do set - // this fixed field because some language runtimes (dlang) get - // confused by ax = 0 and will loop infinitely. - ax = 1 + // We only support reporting cache parameters via + // intelDeterministicCacheParams; report as much here. + // + // We do not support exposing TLB information at all. + ax = 1 | (uint32(intelNoCacheDescriptor) << 8) + case intelDeterministicCacheParams: + if !fs.Intel() { + // Reserved on non-Intel. + return 0, 0, 0, 0 + } + + // cx is the index of the cache to describe. + if int(origCx) >= len(fs.Caches) { + return uint32(cacheNull), 0, 0, 0 + } + c := fs.Caches[origCx] + + ax = uint32(c.Type) + ax |= c.Level << 5 + ax |= 1 << 8 // Always claim the cache is "self-initializing". + if c.FullyAssociative { + ax |= 1 << 9 + } + // Processor topology not supported. + + bx = fs.CacheLine - 1 + bx |= (c.Partitions - 1) << 12 + bx |= (c.Ways - 1) << 22 + + cx = c.Sets - 1 + + if !c.InvalidateHierarchical { + dx |= 1 + } + if c.Inclusive { + dx |= 1 << 1 + } + if !c.DirectMapped { + dx |= 1 << 2 + } case xSaveInfo: if !fs.UseXsave() { return 0, 0, 0, 0 @@ -845,10 +951,41 @@ func HostFeatureSet() *FeatureSet { vendorID := vendorIDFromRegs(bx, cx, dx) // eax=1 gets basic features in ecx:edx. - ax, _, cx, dx := HostID(1, 0) + ax, bx, cx, dx := HostID(1, 0) featureBlock0 := cx featureBlock1 := dx ef, em, pt, f, m, sid := signatureSplit(ax) + cacheLine := 8 * (bx >> 8) & 0xff + + // eax=4, ecx=i gets details about cache index i. Only supported on Intel. + var caches []Cache + if vendorID == intelVendorID { + // ecx selects the cache index until a null type is returned. + for i := uint32(0); ; i++ { + ax, bx, cx, dx := HostID(4, i) + t := CacheType(ax & 0xf) + if t == cacheNull { + break + } + + lineSize := (bx & 0xfff) + 1 + if lineSize != cacheLine { + panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine)) + } + + caches = append(caches, Cache{ + Type: t, + Level: (ax >> 5) & 0x7, + FullyAssociative: ((ax >> 9) & 1) == 1, + Partitions: ((bx >> 12) & 0x3ff) + 1, + Ways: ((bx >> 22) & 0x3ff) + 1, + Sets: cx + 1, + InvalidateHierarchical: (dx & 1) == 0, + Inclusive: ((dx >> 1) & 1) == 1, + DirectMapped: ((dx >> 2) & 1) == 0, + }) + } + } // eax=7, ecx=0 gets extended features in ecx:ebx. _, bx, cx, _ = HostID(7, 0) @@ -883,6 +1020,8 @@ func HostFeatureSet() *FeatureSet { Family: f, Model: m, SteppingID: sid, + CacheLine: cacheLine, + Caches: caches, } } diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_test.go index 6ae14d2da..a707ebb55 100644 --- a/pkg/cpuid/cpuid_test.go +++ b/pkg/cpuid/cpuid_test.go @@ -57,24 +57,13 @@ var justFPUandPAE = &FeatureSet{ X86FeaturePAE: true, }} -func TestIsSubset(t *testing.T) { - if !justFPU.IsSubset(justFPUandPAE) { - t.Errorf("Got %v is not subset of %v, want IsSubset being true", justFPU, justFPUandPAE) +func TestSubtract(t *testing.T) { + if diff := justFPU.Subtract(justFPUandPAE); diff != nil { + t.Errorf("Got %v is not subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff) } - if justFPUandPAE.IsSubset(justFPU) { - t.Errorf("Got %v is a subset of %v, want IsSubset being false", justFPU, justFPUandPAE) - } -} - -func TestTakeFeatureIntersection(t *testing.T) { - testFeatures := HostFeatureSet() - testFeatures.TakeFeatureIntersection(justFPU) - if !testFeatures.IsSubset(justFPU) { - t.Errorf("Got more features than expected after intersecting host features with justFPU: %v, want %v", testFeatures.Set, justFPU.Set) - } - if !testFeatures.HasFeature(X86FeatureFPU) { - t.Errorf("Got no features in testFeatures after intersecting, want %v", X86FeatureFPU) + if justFPUandPAE.Subtract(justFPU) == nil { + t.Errorf("Got %v is a subset of %v, want diff to be nil", justFPU, justFPUandPAE) } } @@ -83,7 +72,7 @@ func TestTakeFeatureIntersection(t *testing.T) { // if HostFeatureSet gives back junk bits. func TestHostFeatureSet(t *testing.T) { hostFeatures := HostFeatureSet() - if !justFPUandPAE.IsSubset(hostFeatures) { + if justFPUandPAE.Subtract(hostFeatures) != nil { t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures) } } @@ -175,6 +164,7 @@ func TestEmulateIDBasicFeatures(t *testing.T) { testFeatures := newEmptyFeatureSet() testFeatures.Add(X86FeatureCLFSH) testFeatures.Add(X86FeatureAVX) + testFeatures.CacheLine = 64 ax, bx, cx, dx := testFeatures.EmulateID(1, 0) ECXAVXBit := uint32(1 << uint(X86FeatureAVX)) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 9961baaa9..71f2abc83 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index 785c685a0..c7f549428 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -7,6 +8,9 @@ go_library( srcs = ["fd.go"], importpath = "gvisor.dev/gvisor/pkg/fd", visibility = ["//visibility:public"], + deps = [ + "//pkg/unet", + ], ) go_test( diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go index 83bcfe220..7691b477b 100644 --- a/pkg/fd/fd.go +++ b/pkg/fd/fd.go @@ -22,6 +22,8 @@ import ( "runtime" "sync/atomic" "syscall" + + "gvisor.dev/gvisor/pkg/unet" ) // ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It @@ -185,6 +187,12 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) { return New(f), nil } +// DialUnix connects to a Unix Domain Socket and return the file descriptor. +func DialUnix(path string) (*FD, error) { + socket, err := unet.Connect(path, false) + return New(socket.FD()), err +} + // Close closes the file descriptor contained in the FD. // // Close is safe to call multiple times, but will return an error after the diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index e54e7371c..56495cbd9 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index bd1d614b6..5643d5f26 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -18,6 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", + "//third_party/gvsync", ], ) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index 865b6f640..8390915a2 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error { *ep.dataLen() = w.Len() // Return control to the client. + raceBecomeInactive() if err := ep.futexSwitchToPeer(); err != nil { return err } @@ -121,7 +122,16 @@ func (ep *Endpoint) enterFutexWait() error { } func (ep *Endpoint) exitFutexWait() { - atomic.AddInt32(&ep.ctrl.state, -epsBlocked) + switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps { + case 0: + return + case epsShutdown: + // ep.ctrlShutdown() was called while we were blocked, so we are + // repsonsible for indicating connection shutdown. + ep.shutdownConn() + default: + panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked)) + } } func (ep *Endpoint) ctrlShutdown() { @@ -142,5 +152,25 @@ func (ep *Endpoint) ctrlShutdown() { break } } + } else { + // There is no blocked thread, so we are responsible for indicating + // connection shutdown. + ep.shutdownConn() + } +} + +func (ep *Endpoint) shutdownConn() { + switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs { + case ep.activeState: + if err := ep.futexWakeConnState(1); err != nil { + log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err) + } + case ep.inactiveState: + // The peer is currently active and will detect shutdown when it tries + // to update the connection state. + case csShutdown: + // The peer also called Endpoint.Shutdown(). + default: + log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs) } } diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 5c9212c33..386cee42c 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -42,11 +42,6 @@ type Endpoint struct { // dataCap is immutable. dataCap uint32 - // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the - // Endpoint has acknowledged shutdown initiated by the peer. shutdown is - // accessed using atomic memory operations. - shutdown uint32 - // activeState is csClientActive if this is a client Endpoint and // csServerActive if this is a server Endpoint. activeState uint32 @@ -55,9 +50,27 @@ type Endpoint struct { // csClientActive if this is a server Endpoint. inactiveState uint32 + // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the + // Endpoint has acknowledged shutdown initiated by the peer. shutdown is + // accessed using atomic memory operations. + shutdown uint32 + ctrl endpointControlImpl } +// EndpointSide indicates which side of a connection an Endpoint belongs to. +type EndpointSide int + +const ( + // ClientSide indicates that an Endpoint is a client (initially-active; + // first method call should be Connect). + ClientSide EndpointSide = iota + + // ServerSide indicates that an Endpoint is a server (initially-inactive; + // first method call should be RecvFirst.) + ServerSide +) + // Init must be called on zero-value Endpoints before first use. If it // succeeds, ep.Destroy() must be called once the Endpoint is no longer in use. // @@ -65,7 +78,17 @@ type Endpoint struct { // Endpoint. FD may differ between Endpoints if they are in different // processes, but must represent the same file. The packet window must // initially be filled with zero bytes. -func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error { +func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error { + switch side { + case ClientSide: + ep.activeState = csClientActive + ep.inactiveState = csServerActive + case ServerSide: + ep.activeState = csServerActive + ep.inactiveState = csClientActive + default: + return fmt.Errorf("invalid EndpointSide: %v", side) + } if pwd.Length < pageSize { return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize) } @@ -78,9 +101,6 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err } ep.packet = m ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes) - // These will be overwritten by ep.Connect() for client Endpoints. - ep.activeState = csServerActive - ep.inactiveState = csClientActive if err := ep.ctrlInit(opts...); err != nil { ep.unmapPacket() return err @@ -90,9 +110,9 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err // NewEndpoint is a convenience function that returns an initialized Endpoint // allocated on the heap. -func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) { +func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) { var ep Endpoint - if err := ep.Init(pwd, opts...); err != nil { + if err := ep.Init(side, pwd, opts...); err != nil { return nil, err } return &ep, nil @@ -115,9 +135,9 @@ func (ep *Endpoint) unmapPacket() { } // Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), -// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not -// wait for concurrent calls to return. The effect of Shutdown on the peer -// Endpoint is unspecified. Successive calls to Shutdown have no effect. +// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer +// Endpoint, to unblock and return errors. It does not wait for concurrent +// calls to return. Successive calls to Shutdown have no effect. // // Shutdown is the only Endpoint method that may be called concurrently with // other methods on the same Endpoint. @@ -152,28 +172,31 @@ const ( // The client is, by definition, initially active, so this must be 0. csClientActive = 0 csServerActive = 1 + csShutdown = 2 ) -// Connect designates ep as a client Endpoint and blocks until the peer -// Endpoint has called Endpoint.RecvFirst(). +// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst(). // -// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and -// ep.SendLast() have never been called. +// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), +// ep.SendRecv(), and ep.SendLast() have never been called. func (ep *Endpoint) Connect() error { - ep.activeState = csClientActive - ep.inactiveState = csServerActive - return ep.ctrlConnect() + err := ep.ctrlConnect() + if err == nil { + raceBecomeActive() + } + return err } // RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then // returns the datagram length specified by that call. // -// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never -// been called. +// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and +// ep.SendLast() have never been called. func (ep *Endpoint) RecvFirst() (uint32, error) { if err := ep.ctrlWaitFirst(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -200,9 +223,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { // after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then // they can only shoot themselves in the foot. *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlRoundTrip(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -222,6 +247,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) } *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlWakeLast(); err != nil { return err } diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go index edb6a8bef..8d88b845d 100644 --- a/pkg/flipcall/flipcall_example_test.go +++ b/pkg/flipcall/flipcall_example_test.go @@ -38,12 +38,12 @@ func Example() { panic(err) } var clientEP Endpoint - if err := clientEP.Init(pwd); err != nil { + if err := clientEP.Init(ClientSide, pwd); err != nil { panic(err) } defer clientEP.Destroy() var serverEP Endpoint - if err := serverEP.Init(pwd); err != nil { + if err := serverEP.Init(ServerSide, pwd); err != nil { panic(err) } defer serverEP.Destroy() diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index da9d736ab..168a487ec 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -39,11 +39,11 @@ func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []Endpoi c.pwa.Destroy() tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err) } - if err := c.clientEP.Init(pwd, clientOpts...); err != nil { + if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil { c.pwa.Destroy() tb.Fatalf("failed to create client Endpoint: %v", err) } - if err := c.serverEP.Init(pwd, serverOpts...); err != nil { + if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil { c.pwa.Destroy() c.clientEP.Destroy() tb.Fatalf("failed to create server Endpoint: %v", err) @@ -62,17 +62,30 @@ func (c *testConnection) destroy() { } func testSendRecv(t *testing.T, c *testConnection) { + // This shared variable is used to confirm that synchronization between + // flipcall endpoints is visible to the Go race detector. + state := 0 var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() t.Logf("server Endpoint waiting for packet 1") if _, err := c.serverEP.RecvFirst(); err != nil { - t.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return + } + state++ + if state != 2 { + t.Errorf("shared state counter: got %d, wanted 2", state) } t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") if _, err := c.serverEP.SendRecv(0); err != nil { - t.Fatalf("server Endpoint.SendRecv() failed: %v", err) + t.Errorf("server Endpoint.SendRecv() failed: %v", err) + return + } + state++ + if state != 4 { + t.Errorf("shared state counter: got %d, wanted 4", state) } t.Logf("server Endpoint got packet 3") }() @@ -87,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) { if err := c.clientEP.Connect(); err != nil { t.Fatalf("client Endpoint.Connect() failed: %v", err) } + state++ + if state != 1 { + t.Errorf("shared state counter: got %d, wanted 1", state) + } t.Logf("client Endpoint sending packet 1 and waiting for packet 2") if _, err := c.clientEP.SendRecv(0); err != nil { t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } + state++ + if state != 3 { + t.Errorf("shared state counter: got %d, wanted 3", state) + } t.Logf("client Endpoint got packet 2, sending packet 3") if err := c.clientEP.SendLast(0); err != nil { t.Fatalf("client Endpoint.SendLast() failed: %v", err) @@ -105,7 +126,30 @@ func TestSendRecv(t *testing.T) { testSendRecv(t, c) } -func testShutdownConnect(t *testing.T, c *testConnection) { +func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + if err := c.clientEP.Connect(); err == nil { + t.Errorf("client Endpoint.Connect() succeeded unexpectedly") + } +} + +func TestShutdownBeforeConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeConnect(t, c, false) +} + +func TestShutdownBeforeConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeConnect(t, c, true) +} + +func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) { var clientRun sync.WaitGroup clientRun.Add(1) go func() { @@ -115,44 +159,86 @@ func testShutdownConnect(t *testing.T, c *testConnection) { } }() time.Sleep(time.Second) // to allow c.clientEP.Connect() to block - c.clientEP.Shutdown() + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } clientRun.Wait() } -func TestShutdownConnect(t *testing.T) { +func TestShutdownDuringConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringConnect(t, c, false) +} + +func TestShutdownDuringConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringConnect(t, c, true) +} + +func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) { + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } + if _, err := c.serverEP.RecvFirst(); err == nil { + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") + } +} + +func TestShutdownBeforeRecvFirstLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeRecvFirst(t, c, false) +} + +func TestShutdownBeforeRecvFirstRemote(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownConnect(t, c) + testShutdownBeforeRecvFirst(t, c, true) } -func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) { +func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() - _, err := c.serverEP.RecvFirst() - if err == nil { + if _, err := c.serverEP.RecvFirst(); err == nil { t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") } }() time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block - c.serverEP.Shutdown() + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } serverRun.Wait() } -func TestShutdownRecvFirstBeforeConnect(t *testing.T) { +func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstBeforeConnect(t, c, false) +} + +func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownRecvFirstBeforeConnect(t, c) + testShutdownDuringRecvFirstBeforeConnect(t, c, true) } -func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) { +func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) { var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() if _, err := c.serverEP.RecvFirst(); err == nil { - t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly") + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") } }() defer func() { @@ -164,23 +250,75 @@ func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) { if err := c.clientEP.Connect(); err != nil { t.Fatalf("client Endpoint.Connect() failed: %v", err) } - c.serverEP.Shutdown() + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } serverRun.Wait() } -func TestShutdownRecvFirstAfterConnect(t *testing.T) { +func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstAfterConnect(t, c, false) +} + +func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstAfterConnect(t, c, true) +} + +func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err != nil { + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + } + // At this point, the client must be blocked in c.clientEP.SendRecv(). + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + }() + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. + c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) + } + if _, err := c.clientEP.SendRecv(0); err == nil { + t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly") + } +} + +func TestShutdownDuringClientSendRecvLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringClientSendRecv(t, c, false) +} + +func TestShutdownDuringClientSendRecvRemote(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownRecvFirstAfterConnect(t, c) + testShutdownDuringClientSendRecv(t, c, true) } -func testShutdownSendRecv(t *testing.T, c *testConnection) { +func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() if _, err := c.serverEP.RecvFirst(); err != nil { - t.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return } if _, err := c.serverEP.SendRecv(0); err == nil { t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly") @@ -199,14 +337,24 @@ func testShutdownSendRecv(t *testing.T, c *testConnection) { t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } time.Sleep(time.Second) // to allow serverEP.SendRecv() to block - c.serverEP.Shutdown() + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } serverRun.Wait() } -func TestShutdownSendRecv(t *testing.T) { +func TestShutdownDuringServerSendRecvLocal(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownSendRecv(t, c) + testShutdownDuringServerSendRecv(t, c, false) +} + +func TestShutdownDuringServerSendRecvRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringServerSendRecv(t, c, true) } func benchmarkSendRecv(b *testing.B, c *testConnection) { @@ -218,15 +366,17 @@ func benchmarkSendRecv(b *testing.B, c *testConnection) { return } if _, err := c.serverEP.RecvFirst(); err != nil { - b.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + b.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return } for i := 1; i < b.N; i++ { if _, err := c.serverEP.SendRecv(0); err != nil { - b.Fatalf("server Endpoint.SendRecv() failed: %v", err) + b.Errorf("server Endpoint.SendRecv() failed: %v", err) + return } } if err := c.serverEP.SendLast(0); err != nil { - b.Fatalf("server Endpoint.SendLast() failed: %v", err) + b.Errorf("server Endpoint.SendLast() failed: %v", err) } }() defer func() { diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 7c8977893..a37952637 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -17,17 +17,19 @@ package flipcall import ( "reflect" "unsafe" + + "gvisor.dev/gvisor/third_party/gvsync" ) -// Packets consist of an 8-byte header followed by an arbitrarily-sized +// Packets consist of a 16-byte header followed by an arbitrarily-sized // datagram. The header consists of: // // - A 4-byte native-endian connection state. // // - A 4-byte native-endian datagram length in bytes. +// +// - 8 reserved bytes. const ( - sizeofUint32 = unsafe.Sizeof(uint32(0)) - // PacketHeaderBytes is the size of a flipcall packet header in bytes. The // maximum datagram size supported by a flipcall connection is equal to the // length of the packet window minus PacketHeaderBytes. @@ -35,7 +37,7 @@ const ( // PacketHeaderBytes is exported to support its use in constant // expressions. Non-constant expressions may prefer to use // PacketWindowLengthForDataCap(). - PacketHeaderBytes = 2 * sizeofUint32 + PacketHeaderBytes = 16 ) func (ep *Endpoint) connState() *uint32 { @@ -43,7 +45,7 @@ func (ep *Endpoint) connState() *uint32 { } func (ep *Endpoint) dataLen() *uint32 { - return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32)) + return (*uint32)((unsafe.Pointer)(ep.packet + 4)) } // Data returns the datagram part of ep's packet window as a byte slice. @@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte { bsReflect.Cap = int(ep.dataCap) return bs } + +// ioSync is a dummy variable used to indicate synchronization to the Go race +// detector. Compare syscall.ioSync. +var ioSync int64 + +func raceBecomeActive() { + if gvsync.RaceEnabled { + gvsync.RaceAcquire((unsafe.Pointer)(&ioSync)) + } +} + +func raceBecomeInactive() { + if gvsync.RaceEnabled { + gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + } +} diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index e7dd812b3..b127a2bbb 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -59,7 +59,12 @@ func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeRespon func (ep *Endpoint) futexSwitchToPeer() error { // Update connection state to indicate that the peer should be active. if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { - return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState())) + switch cs := atomic.LoadUint32(ep.connState()); cs { + case csShutdown: + return shutdownError{} + default: + return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) + } } // Wake the peer's Endpoint.futexSwitchFromPeer(). @@ -75,16 +80,18 @@ func (ep *Endpoint) futexSwitchFromPeer() error { case ep.activeState: return nil case ep.inactiveState: - // Continue to FUTEX_WAIT. + if ep.isShutdownLocally() { + return shutdownError{} + } + if err := ep.futexWaitConnState(ep.inactiveState); err != nil { + return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) + } + continue + case csShutdown: + return shutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } - if ep.isShutdownLocally() { - return shutdownError{} - } - if err := ep.futexWaitConnState(ep.inactiveState); err != nil { - return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) - } } } diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index 11716af81..0c5f50397 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index e6a8dbd02..4b9321711 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD index 8f3defa25..34d2673ef 100644 --- a/pkg/ilist/BUILD +++ b/pkg/ilist/BUILD @@ -1,5 +1,6 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index c8e923a74..a5d980d14 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 12615240c..fc5f5779b 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 3b8a691f4..dd6ca6d39 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -1,5 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -21,6 +23,12 @@ proto_library( visibility = ["//:sandbox"], ) +cc_proto_library( + name = "metric_cc_proto", + visibility = ["//:sandbox"], + deps = [":metric_proto"], +) + go_proto_library( name = "metric_go_proto", importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto", diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index c6737bf97..f32244c69 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], @@ -19,11 +20,14 @@ go_library( "pool.go", "server.go", "transport.go", + "transport_flipcall.go", "version.go", ], importpath = "gvisor.dev/gvisor/pkg/p9", deps = [ "//pkg/fd", + "//pkg/fdchannel", + "//pkg/flipcall", "//pkg/log", "//pkg/unet", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 7dc20aeef..221516c6c 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -20,6 +20,8 @@ import ( "sync" "syscall" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -77,6 +79,47 @@ type Client struct { // fidPool is the collection of available fids. fidPool pool + // messageSize is the maximum total size of a message. + messageSize uint32 + + // payloadSize is the maximum payload size of a read or write. + // + // For large reads and writes this means that the read or write is + // broken up into buffer-size/payloadSize requests. + payloadSize uint32 + + // version is the agreed upon version X of 9P2000.L.Google.X. + // version 0 implies 9P2000.L. + version uint32 + + // closedWg is marked as done when the Client.watch() goroutine, which is + // responsible for closing channels and the socket fd, returns. + closedWg sync.WaitGroup + + // sendRecv is the transport function. + // + // This is determined dynamically based on whether or not the server + // supports flipcall channels (preferred as it is faster and more + // efficient, and does not require tags). + sendRecv func(message, message) error + + // -- below corresponds to sendRecvChannel -- + + // channelsMu protects channels. + channelsMu sync.Mutex + + // channelsWg counts the number of channels for which channel.active == + // true. + channelsWg sync.WaitGroup + + // channels is the set of all initialized channels. + channels []*channel + + // availableChannels is a FIFO of inactive channels. + availableChannels []*channel + + // -- below corresponds to sendRecvLegacy -- + // pending is the set of pending messages. pending map[Tag]*response pendingMu sync.Mutex @@ -89,25 +132,12 @@ type Client struct { // Whoever writes to this channel is permitted to call recv. When // finished calling recv, this channel should be emptied. recvr chan bool - - // messageSize is the maximum total size of a message. - messageSize uint32 - - // payloadSize is the maximum payload size of a read or write - // request. For large reads and writes this means that the - // read or write is broken up into buffer-size/payloadSize - // requests. - payloadSize uint32 - - // version is the agreed upon version X of 9P2000.L.Google.X. - // version 0 implies 9P2000.L. - version uint32 } // NewClient creates a new client. It performs a Tversion exchange with // the server to assert that messageSize is ok to use. // -// You should not use the same socket for multiple clients. +// If NewClient succeeds, ownership of socket is transferred to the new Client. func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { // Need at least one byte of payload. if messageSize <= msgRegistry.largestFixedSize { @@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client return nil, ErrBadVersionString } for { + // Always exchange the version using the legacy version of the + // protocol. If the protocol supports flipcall, then we switch + // our sendRecv function to use that functionality. Otherwise, + // we stick to sendRecvLegacy. rversion := Rversion{} - err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion) + err := c.sendRecvLegacy(&Tversion{ + Version: versionString(requested), + MSize: messageSize, + }, &rversion) // The server told us to try again with a lower version. if err == syscall.EAGAIN { @@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client c.version = version break } + + // Can we switch to use the more advanced channels and create + // independent channels for communication? Prefer it if possible. + if versionSupportsFlipcall(c.version) { + // Attempt to initialize IPC-based communication. + for i := 0; i < channelsPerClient; i++ { + if err := c.openChannel(i); err != nil { + log.Warningf("error opening flipcall channel: %v", err) + break // Stop. + } + } + if len(c.channels) >= 1 { + // At least one channel created. + c.sendRecv = c.sendRecvChannel + } else { + // Channel setup failed; fallback. + c.sendRecv = c.sendRecvLegacy + } + } else { + // No channels available: use the legacy mechanism. + c.sendRecv = c.sendRecvLegacy + } + + // Ensure that the socket and channels are closed when the socket is shut + // down. + c.closedWg.Add(1) + go c.watch(socket) // S/R-SAFE: not relevant. + return c, nil } +// watch watches the given socket and releases resources on hangup events. +// +// This is intended to be called as a goroutine. +func (c *Client) watch(socket *unet.Socket) { + defer c.closedWg.Done() + + events := []unix.PollFd{ + unix.PollFd{ + Fd: int32(socket.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == syscall.EINTR || err == syscall.EAGAIN { + continue + } + if err != nil { + log.Warningf("p9.Client.watch(): %v", err) + break + } + if n != 1 { + log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Set availableChannels to nil so that future calls to c.sendRecvChannel() + // don't attempt to activate a channel, and concurrent calls to + // c.sendRecvChannel() don't mark released channels as available. + c.channelsMu.Lock() + c.availableChannels = nil + + // Shut down all active channels. + for _, ch := range c.channels { + if ch.active { + log.Debugf("shutting down active channel@%p...", ch) + ch.Shutdown() + } + } + c.channelsMu.Unlock() + + // Wait for active channels to become inactive. + c.channelsWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.Close() + } + c.channelsMu.Unlock() + + // Close the main socket. + c.socket.Close() +} + +// openChannel attempts to open a client channel. +// +// Note that this function returns naked errors which should not be propagated +// directly to a caller. It is expected that the errors will be logged and a +// fallback path will be used instead. +func (c *Client) openChannel(id int) error { + var ( + rchannel0 Rchannel + rchannel1 Rchannel + res = new(channel) + ) + + // Open the data channel. + if err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 0, + }, &rchannel0); err != nil { + return fmt.Errorf("error handling Tchannel message: %v", err) + } + if rchannel0.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on primary channel") + } + + // We don't need to hold this. + defer rchannel0.FilePayload().Close() + + // Open the channel for file descriptors. + if err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 1, + }, &rchannel1); err != nil { + return err + } + if rchannel1.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on file descriptor channel") + } + + // Construct the endpoints. + res.desc = flipcall.PacketWindowDescriptor{ + FD: rchannel0.FilePayload().FD(), + Offset: int64(rchannel0.Offset), + Length: int(rchannel0.Length), + } + if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil { + rchannel1.FilePayload().Close() + return err + } + + // The fds channel owns the control payload, and it will be closed when + // the channel object is closed. + res.fds.Init(rchannel1.FilePayload().Release()) + + // Save the channel. + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + c.channels = append(c.channels, res) + c.availableChannels = append(c.availableChannels, res) + return nil +} + // handleOne handles a single incoming message. // // This should only be called with the token from recvr. Note that the received @@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error { } } -// sendRecv performs a roundtrip message exchange. +// sendRecvLegacy performs a roundtrip message exchange. // // This is called by internal functions. -func (c *Client) sendRecv(t message, r message) error { +func (c *Client) sendRecvLegacy(t message, r message) error { tag, ok := c.tagPool.Get() if !ok { return ErrOutOfTags @@ -296,12 +479,77 @@ func (c *Client) sendRecv(t message, r message) error { return nil } +// sendRecvChannel uses channels to send a message. +func (c *Client) sendRecvChannel(t message, r message) error { + // Acquire an available channel. + c.channelsMu.Lock() + if len(c.availableChannels) == 0 { + c.channelsMu.Unlock() + return c.sendRecvLegacy(t, r) + } + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + ch.active = true + c.channelsWg.Add(1) + c.channelsMu.Unlock() + + // Ensure that it's connected. + if !ch.connected { + ch.connected = true + if err := ch.data.Connect(); err != nil { + // The channel is unusable, so don't return it to + // c.availableChannels. However, we still have to mark it as + // inactive so c.watch() doesn't wait for it. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + // Map all transport errors to EIO, but ensure that the real error + // is logged. + log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err) + return syscall.EIO + } + } + + // Send the request and receive the server's response. + rsz, err := ch.send(t) + if err != nil { + // See above. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err) + return syscall.EIO + } + + // Parse the server's response. + _, retErr := ch.recv(r, rsz) + + // Release the channel. + c.channelsMu.Lock() + ch.active = false + // If c.availableChannels is nil, c.watch() has fired and we should not + // mark this channel as available. + if c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.channelsMu.Unlock() + c.channelsWg.Done() + + return retErr +} + // Version returns the negotiated 9P2000.L.Google version number. func (c *Client) Version() uint32 { return c.version } -// Close closes the underlying socket. -func (c *Client) Close() error { - return c.socket.Close() +// Close closes the underlying socket and channels. +func (c *Client) Close() { + // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already + // been called (by c.watch()). + c.socket.Shutdown() + c.closedWg.Wait() } diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index 87b2dd61e..29a0afadf 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -35,23 +35,23 @@ func TestVersion(t *testing.T) { go s.Handle(serverSocket) // NewClient does a Tversion exchange, so this is our test for success. - c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString()) + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) if err != nil { t.Fatalf("got %v, expected nil", err) } // Check a bogus version string. - if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { + if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { t.Errorf("got %v expected %v", err, syscall.EINVAL) } // Check a bogus version number. - if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { + if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { t.Errorf("got %v expected %v", err, syscall.EINVAL) } // Check a too high version number. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN { + if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN { t.Errorf("got %v expected %v", err, syscall.EAGAIN) } @@ -60,3 +60,45 @@ func TestVersion(t *testing.T) { t.Errorf("got %v expected %v", err, syscall.EINVAL) } } + +func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { + // See above. + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + b.Fatalf("socketpair got err %v expected nil", err) + } + defer clientSocket.Close() + + // See above. + s := NewServer(nil) + go s.Handle(serverSocket) + + // See above. + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) + if err != nil { + b.Fatalf("got %v, expected nil", err) + } + + // Initialize messages. + sendRecv := fn(c) + tversion := &Tversion{ + Version: versionString(highestSupportedVersion), + MSize: DefaultMessageSize, + } + rversion := new(Rversion) + + // Run in a loop. + for i := 0; i < b.N; i++ { + if err := sendRecv(tversion, rversion); err != nil { + b.Fatalf("got unexpected err: %v", err) + } + } +} + +func BenchmarkSendRecvLegacy(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy }) +} + +func BenchmarkSendRecvChannel(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel }) +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 999b4f684..ba9a55d6d 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -305,7 +305,9 @@ func (t *Tlopen) handle(cs *connState) message { ref.opened = true ref.openFlags = t.Flags - return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile} + rlopen := &Rlopen{QID: qid, IoUnit: ioUnit} + rlopen.SetFilePayload(osFile) + return rlopen } func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { @@ -364,7 +366,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { // Replace the FID reference. cs.InsertFID(t.FID, newRef) - return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil + rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}} + rlcreate.SetFilePayload(osFile) + return rlcreate, nil } // handle implements handler.handle. @@ -1287,5 +1291,48 @@ func (t *Tlconnect) handle(cs *connState) message { return newErr(err) } - return &Rlconnect{File: osFile} + rlconnect := &Rlconnect{} + rlconnect.SetFilePayload(osFile) + return rlconnect +} + +// handle implements handler.handle. +func (t *Tchannel) handle(cs *connState) message { + // Ensure that channels are enabled. + if err := cs.initializeChannels(); err != nil { + return newErr(err) + } + + // Lookup the given channel. + ch := cs.lookupChannel(t.ID) + if ch == nil { + return newErr(syscall.ENOSYS) + } + + // Return the payload. Note that we need to duplicate the file + // descriptor for the channel allocator, because sending is a + // destructive operation between sendRecvLegacy (and now the newer + // channel send operations). Same goes for the client FD. + rchannel := &Rchannel{ + Offset: uint64(ch.desc.Offset), + Length: uint64(ch.desc.Length), + } + switch t.Control { + case 0: + // Open the main data channel. + mfd, err := syscall.Dup(int(cs.channelAlloc.FD())) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(mfd)) + case 1: + cfd, err := syscall.Dup(ch.client.FD()) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(cfd)) + default: + return newErr(syscall.EINVAL) + } + return rchannel } diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index fd9eb1c5d..ffdd7e8c6 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -64,6 +64,21 @@ type filer interface { SetFilePayload(*fd.FD) } +// filePayload embeds a File object. +type filePayload struct { + File *fd.FD +} + +// FilePayload returns the file payload. +func (f *filePayload) FilePayload() *fd.FD { + return f.File +} + +// SetFilePayload sets the received file. +func (f *filePayload) SetFilePayload(file *fd.FD) { + f.File = file +} + // Tversion is a version request. type Tversion struct { // MSize is the message size to use. @@ -524,10 +539,7 @@ type Rlopen struct { // IoUnit is the recommended I/O unit. IoUnit uint32 - // File may be attached via the socket. - // - // This is an extension specific to this package. - File *fd.FD + filePayload } // Decode implements encoder.Decode. @@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType { return MsgRlopen } -// FilePayload returns the file payload. -func (r *Rlopen) FilePayload() *fd.FD { - return r.File -} - -// SetFilePayload sets the received file. -func (r *Rlopen) SetFilePayload(file *fd.FD) { - r.File = file -} - // String implements fmt.Stringer. func (r *Rlopen) String() string { return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File) @@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string { // Rlconnect is a connect response. type Rlconnect struct { - // File is a host socket. - File *fd.FD + filePayload } // Decode implements encoder.Decode. @@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType { return MsgRlconnect } -// FilePayload returns the file payload. -func (r *Rlconnect) FilePayload() *fd.FD { - return r.File +// String implements fmt.Stringer. +func (r *Rlconnect) String() string { + return fmt.Sprintf("Rlconnect{File: %v}", r.File) } -// SetFilePayload sets the received file. -func (r *Rlconnect) SetFilePayload(file *fd.FD) { - r.File = file +// Tchannel creates a new channel. +type Tchannel struct { + // ID is the channel ID. + ID uint32 + + // Control is 0 if the Rchannel response should provide the flipcall + // component of the channel, and 1 if the Rchannel response should + // provide the fdchannel component of the channel. + Control uint32 +} + +// Decode implements encoder.Decode. +func (t *Tchannel) Decode(b *buffer) { + t.ID = b.Read32() + t.Control = b.Read32() +} + +// Encode implements encoder.Encode. +func (t *Tchannel) Encode(b *buffer) { + b.Write32(t.ID) + b.Write32(t.Control) +} + +// Type implements message.Type. +func (*Tchannel) Type() MsgType { + return MsgTchannel } // String implements fmt.Stringer. -func (r *Rlconnect) String() string { - return fmt.Sprintf("Rlconnect{File: %v}", r.File) +func (t *Tchannel) String() string { + return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control) +} + +// Rchannel is the channel response. +type Rchannel struct { + Offset uint64 + Length uint64 + filePayload +} + +// Decode implements encoder.Decode. +func (r *Rchannel) Decode(b *buffer) { + r.Offset = b.Read64() + r.Length = b.Read64() +} + +// Encode implements encoder.Encode. +func (r *Rchannel) Encode(b *buffer) { + b.Write64(r.Offset) + b.Write64(r.Length) +} + +// Type implements message.Type. +func (*Rchannel) Type() MsgType { + return MsgRchannel +} + +// String implements fmt.Stringer. +func (r *Rchannel) String() string { + return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) } const maxCacheSize = 3 @@ -2356,4 +2409,6 @@ func init() { msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} }) msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) + msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) + msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index e12831dbd..25530adca 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -378,6 +378,8 @@ const ( MsgRlconnect = 137 MsgTallocate = 138 MsgRallocate = 139 + MsgTchannel = 250 + MsgRchannel = 251 ) // QIDType represents the file type for QIDs. diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 6e939a49a..28707c0ca 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") package(licenses = ["notice"]) @@ -77,7 +77,7 @@ go_library( go_test( name = "client_test", - size = "small", + size = "medium", srcs = ["client_test.go"], embed = [":p9test"], deps = [ diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index fe649c2e8..8bbdb2488 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) { } } } + +func TestReadWriteConcurrent(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + const ( + instances = 10 + iterations = 10000 + dataSize = 1024 + ) + var ( + dataSets [instances][dataSize]byte + backends [instances]*Mock + files [instances]p9.File + ) + + // Walk to the file normally. + for i := 0; i < instances; i++ { + _, backends[i], files[i] = walkHelper(h, "file", root) + defer files[i].Close() + } + + // Open the files. + for i := 0; i < instances; i++ { + backends[i].EXPECT().Open(p9.ReadWrite) + if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil { + t.Fatalf("open got %v, wanted nil", err) + } + } + + // Initialize random data for each instance. + for i := 0; i < instances; i++ { + if _, err := rand.Read(dataSets[i][:]); err != nil { + t.Fatalf("error initializing dataSet#%d, got %v", i, err) + } + } + + // Define our random read/write mechanism. + randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) { + // Prepare the backend. + backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if n := copy(p, data); n != len(data) { + // Note that we have to assert the result here, as the Return statement + // below cannot be dynamic: it will be bound before this call is made. + h.t.Errorf("wanted length %d, got %d", len(data), n) + } + }).Return(len(data), nil) + + // Execute the read. + if n, err := f.ReadAt(test, 0); n != len(test) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err) + return // No sense doing check below. + } + if !bytes.Equal(test, data) { + t.Errorf("data integrity failed during read") // Not as expected. + } + } + randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) { + // Prepare the backend. + backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if !bytes.Equal(p, data) { + h.t.Errorf("data integrity failed during write") // Not as expected. + } + }).Return(len(data), nil) + + // Execute the write. + if n, err := f.WriteAt(data, 0); n != len(data) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err) + } + } + randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) { + test := make([]byte, len(data)) + for i := 0; i < n; i++ { + if rand.Intn(2) == 0 { + randRead(h, backend, f, data, test) + } else { + randWrite(h, backend, f, data) + } + } + } + + // Start reading and writing. + var wg sync.WaitGroup + for i := 0; i < instances; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:]) + }(i) + } + wg.Wait() +} diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 95846e5f7..4d3271b37 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator { // Finish completes all checks and shuts down the server. func (h *Harness) Finish() { - h.clientSocket.Close() + h.clientSocket.Shutdown() h.wg.Wait() h.mockCtrl.Finish() } @@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) { }() // Create the client. - client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString()) + client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString()) if err != nil { serverSocket.Close() clientSocket.Close() diff --git a/pkg/p9/server.go b/pkg/p9/server.go index b294efbb0..e717e6161 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -21,6 +21,9 @@ import ( "sync/atomic" "syscall" + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -45,7 +48,6 @@ type Server struct { } // NewServer returns a new server. -// func NewServer(attacher Attacher) *Server { return &Server{ attacher: attacher, @@ -85,6 +87,8 @@ type connState struct { // version 0 implies 9P2000.L. version uint32 + // -- below relates to the legacy handler -- + // recvOkay indicates that a receive may start. recvOkay chan bool @@ -93,6 +97,20 @@ type connState struct { // sendDone is signalled when a send is finished. sendDone chan error + + // -- below relates to the flipcall handler -- + + // channelMu protects below. + channelMu sync.Mutex + + // channelWg represents active workers. + channelWg sync.WaitGroup + + // channelAlloc allocates channel memory. + channelAlloc *flipcall.PacketWindowAllocator + + // channels are the set of initialized channels. + channels []*channel } // fidRef wraps a node and tracks references. @@ -386,6 +404,101 @@ func (cs *connState) WaitTag(t Tag) { <-ch } +// initializeChannels initializes all channels. +// +// This is a no-op if channels are already initialized. +func (cs *connState) initializeChannels() (err error) { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + + // Initialize our channel allocator. + if cs.channelAlloc == nil { + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return err + } + cs.channelAlloc = alloc + } + + // Create all the channels. + for len(cs.channels) < channelsPerClient { + res := &channel{ + done: make(chan struct{}), + } + + res.desc, err = cs.channelAlloc.Allocate(channelSize) + if err != nil { + return err + } + if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil { + return err + } + + socks, err := fdchannel.NewConnectedSockets() + if err != nil { + res.data.Destroy() // Cleanup. + return err + } + res.fds.Init(socks[0]) + res.client = fd.New(socks[1]) + + cs.channels = append(cs.channels, res) + + // Start servicing the channel. + // + // When we call stop, we will close all the channels and these + // routines should finish. We need the wait group to ensure + // that active handlers are actually finished before cleanup. + cs.channelWg.Add(1) + go func() { // S/R-SAFE: Server side. + defer cs.channelWg.Done() + if err := res.service(cs); err != nil { + log.Warningf("p9.channel.service: %v", err) + } + }() + } + + return nil +} + +// lookupChannel looks up the channel with given id. +// +// The function returns nil if no such channel is available. +func (cs *connState) lookupChannel(id uint32) *channel { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + if id >= uint32(len(cs.channels)) { + return nil + } + return cs.channels[id] +} + +// handle handles a single message. +func (cs *connState) handle(m message) (r message) { + defer func() { + if r == nil { + // Don't allow a panic to propagate. + recover() + + // Include a useful log message. + log.Warningf("panic in handler: %s", debug.Stack()) + + // Wrap in an EFAULT error; we don't really have a + // better way to describe this kind of error. It will + // usually manifest as a result of the test framework. + r = newErr(syscall.EFAULT) + } + }() + if handler, ok := m.(handler); ok { + // Call the message handler. + r = handler.handle(cs) + } else { + // Produce an ENOSYS error. + r = newErr(syscall.ENOSYS) + } + return +} + // handleRequest handles a single request. // // The recvDone channel is signaled when recv is done (with a error if @@ -428,41 +541,20 @@ func (cs *connState) handleRequest() { } // Handle the message. - var r message // r is the response. - defer func() { - if r == nil { - // Don't allow a panic to propagate. - recover() - - // Include a useful log message. - log.Warningf("panic in handler: %s", debug.Stack()) + r := cs.handle(m) - // Wrap in an EFAULT error; we don't really have a - // better way to describe this kind of error. It will - // usually manifest as a result of the test framework. - r = newErr(syscall.EFAULT) - } + // Clear the tag before sending. That's because as soon as this hits + // the wire, the client can legally send the same tag. + cs.ClearTag(tag) - // Clear the tag before sending. That's because as soon as this - // hits the wire, the client can legally send another message - // with the same tag. - cs.ClearTag(tag) + // Send back the result. + cs.sendMu.Lock() + err = send(cs.conn, tag, r) + cs.sendMu.Unlock() + cs.sendDone <- err - // Send back the result. - cs.sendMu.Lock() - err = send(cs.conn, tag, r) - cs.sendMu.Unlock() - cs.sendDone <- err - }() - if handler, ok := m.(handler); ok { - // Call the message handler. - r = handler.handle(cs) - } else { - // Produce an ENOSYS error. - r = newErr(syscall.ENOSYS) - } + // Return the message to the cache. msgRegistry.put(m) - m = nil // 'm' should not be touched after this point. } func (cs *connState) handleRequests() { @@ -477,7 +569,27 @@ func (cs *connState) stop() { close(cs.recvDone) close(cs.sendDone) - for _, fidRef := range cs.fids { + // Free the channels. + cs.channelMu.Lock() + for _, ch := range cs.channels { + ch.Shutdown() + } + cs.channelWg.Wait() + for _, ch := range cs.channels { + ch.Close() + } + cs.channels = nil // Clear. + cs.channelMu.Unlock() + + // Free the channel memory. + if cs.channelAlloc != nil { + cs.channelAlloc.Destroy() + } + + // Close all remaining fids. + for fid, fidRef := range cs.fids { + delete(cs.fids, fid) + // Drop final reference in the FID table. Note this should // always close the file, since we've ensured that there are no // handlers running via the wait for Pending => 0 below. @@ -510,7 +622,7 @@ func (cs *connState) service() error { for i := 0; i < pending; i++ { <-cs.sendDone } - return err + return nil } // This handler is now pending. diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 5648df589..6e8b4bbcd 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -54,7 +54,10 @@ const ( headerLength uint32 = 7 // maximumLength is the largest possible message. - maximumLength uint32 = 4 * 1024 * 1024 + maximumLength uint32 = 1 << 20 + + // DefaultMessageSize is a sensible default. + DefaultMessageSize uint32 = 64 << 10 // initialBufferLength is the initial data buffer we allocate. initialBufferLength uint32 = 64 diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go new file mode 100644 index 000000000..233f825e3 --- /dev/null +++ b/pkg/p9/transport_flipcall.go @@ -0,0 +1,243 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "runtime" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +// channelsPerClient is the number of channels to create per client. +// +// While the client and server will generally agree on this number, in reality +// it's completely up to the server. We simply define a minimum of 2, and a +// maximum of 4, and select the number of available processes as a tie-breaker. +// Note that we don't want the number of channels to be too large, because each +// will account for channelSize memory used, which can be large. +var channelsPerClient = func() int { + n := runtime.NumCPU() + if n < 2 { + return 2 + } + if n > 4 { + return 4 + } + return n +}() + +// channelSize is the channel size to create. +// +// We simply ensure that this is larger than the largest possible message size, +// plus the flipcall packet header, plus the two bytes we write below. +const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength) + +// channel is a fast IPC channel. +// +// The same object is used by both the server and client implementations. In +// general, the client will use only the send and recv methods. +type channel struct { + desc flipcall.PacketWindowDescriptor + data flipcall.Endpoint + fds fdchannel.Endpoint + buf buffer + + // -- client only -- + connected bool + active bool + + // -- server only -- + client *fd.FD + done chan struct{} +} + +// reset resets the channel buffer. +func (ch *channel) reset(sz uint32) { + ch.buf.data = ch.data.Data()[:sz] +} + +// service services the channel. +func (ch *channel) service(cs *connState) error { + rsz, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rsz > 0 { + m, err := ch.recv(nil, rsz) + if err != nil { + return err + } + r := cs.handle(m) + msgRegistry.put(m) + rsz, err = ch.send(r) + if err != nil { + return err + } + } + return nil // Done. +} + +// Shutdown shuts down the channel. +// +// This must be called before Close. +func (ch *channel) Shutdown() { + ch.data.Shutdown() +} + +// Close closes the channel. +// +// This must only be called once, and cannot return an error. Note that +// synchronization for this method is provided at a high-level, depending on +// whether it is the client or server. This cannot be called while there are +// active callers in either service or sendRecv. +// +// Precondition: the channel should be shutdown. +func (ch *channel) Close() error { + // Close all backing transports. + ch.fds.Destroy() + ch.data.Destroy() + if ch.client != nil { + ch.client.Close() + } + return nil +} + +// send sends the given message. +// +// The return value is the size of the received response. Not that in the +// server case, this is the size of the next request. +func (ch *channel) send(m message) (uint32, error) { + if log.IsLogging(log.Debug) { + log.Debugf("send [channel @%p] %s", ch, m.String()) + } + + // Send any file payload. + sentFD := false + if filer, ok := m.(filer); ok { + if f := filer.FilePayload(); f != nil { + if err := ch.fds.SendFD(f.FD()); err != nil { + return 0, err + } + f.Close() // Per sendRecvLegacy. + sentFD = true // To mark below. + } + } + + // Encode the message. + // + // Note that IPC itself encodes the length of messages, so we don't + // need to encode a standard 9P header. We write only the message type. + ch.reset(0) + + ch.buf.WriteMsgType(m.Type()) + if sentFD { + ch.buf.Write8(1) // Incoming FD. + } else { + ch.buf.Write8(0) // No incoming FD. + } + m.Encode(&ch.buf) + ssz := uint32(len(ch.buf.data)) // Updated below. + + // Is there a payload? + if payloader, ok := m.(payloader); ok { + p := payloader.Payload() + copy(ch.data.Data()[ssz:], p) + ssz += uint32(len(p)) + } + + // Perform the one-shot communication. + return ch.data.SendRecv(ssz) +} + +// recv decodes a message that exists on the channel. +// +// If the passed r is non-nil, then the type must match or an error will be +// generated. If the passed r is nil, then a new message will be created and +// returned. +func (ch *channel) recv(r message, rsz uint32) (message, error) { + // Decode the response from the inline buffer. + ch.reset(rsz) + t := ch.buf.ReadMsgType() + hasFD := ch.buf.Read8() != 0 + if t == MsgRlerror { + // Change the message type. We check for this special case + // after decoding below, and transform into an error. + r = &Rlerror{} + } else if r == nil { + nr, err := msgRegistry.get(0, t) + if err != nil { + return nil, err + } + r = nr // New message. + } else if t != r.Type() { + // Not an error and not the expected response; propagate. + return nil, &ErrBadResponse{Got: t, Want: r.Type()} + } + + // Is there a payload? Copy from the latter portion. + if payloader, ok := r.(payloader); ok { + fs := payloader.FixedSize() + p := payloader.Payload() + payloadData := ch.buf.data[fs:] + if len(p) < len(payloadData) { + p = make([]byte, len(payloadData)) + copy(p, payloadData) + payloader.SetPayload(p) + } else if n := copy(p, payloadData); n < len(p) { + payloader.SetPayload(p[:n]) + } + ch.buf.data = ch.buf.data[:fs] + } + + r.Decode(&ch.buf) + if ch.buf.isOverrun() { + // Nothing valid was available. + log.Debugf("recv [got %d bytes, needed more]", rsz) + return nil, ErrNoValidMessage + } + + // Read any FD result. + if hasFD { + if rfd, err := ch.fds.RecvFDNonblock(); err == nil { + f := fd.New(rfd) + if filer, ok := r.(filer); ok { + // Set the payload. + filer.SetFilePayload(f) + } else { + // Don't want the FD. + f.Close() + } + } else { + // The header bit was set but nothing came in. + log.Warningf("expected FD, got err: %v", err) + } + } + + // Log a message. + if log.IsLogging(log.Debug) { + log.Debugf("recv [channel @%p] %s", ch, r.String()) + } + + // Convert errors appropriately; see above. + if rlerr, ok := r.(*Rlerror); ok { + return nil, syscall.Errno(rlerr.Error) + } + + return r, nil +} diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index cdb3bc841..2f50ff3ea 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) { t.Fatalf("unable to create file: %v", err) } - if err := send(client, Tag(1), &Rlopen{File: f}); err != nil { + rlopen := &Rlopen{} + rlopen.SetFilePayload(f) + if err := send(client, Tag(1), rlopen); err != nil { t.Fatalf("send got err %v expected nil", err) } diff --git a/pkg/p9/version.go b/pkg/p9/version.go index c2a2885ae..f1ffdd23a 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 7 + highestSupportedVersion uint32 = 8 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -148,3 +148,10 @@ func VersionSupportsMultiUser(v uint32) bool { func versionSupportsTallocate(v uint32) bool { return v >= 7 } + +// versionSupportsFlipcall returns true if version v supports IPC channels from +// the flipcall package. Note that these must be negotiated, but this version +// string indicates that such a facility exists. +func versionSupportsFlipcall(v uint32) bool { + return v >= 8 +} diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 697e7a2f4..078f084b2 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 9c08452fc..827385139 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "weak_ref_list", diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 828e9b5c1..ad69e0757 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -215,8 +215,8 @@ type AtomicRefCount struct { type LeakMode uint32 const ( - // uninitializedLeakChecking indicates that the leak checker has not yet been initialized. - uninitializedLeakChecking LeakMode = iota + // UninitializedLeakChecking indicates that the leak checker has not yet been initialized. + UninitializedLeakChecking LeakMode = iota // NoLeakChecking indicates that no effort should be made to check for // leaks. @@ -318,7 +318,7 @@ func (r *AtomicRefCount) finalize() { switch LeakMode(atomic.LoadUint32(&leakMode)) { case NoLeakChecking: return - case uninitializedLeakChecking: + case UninitializedLeakChecking: note = "(Leak checker uninitialized): " } if n := r.ReadRefs(); n != 0 { diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index d1024e49d..af94e944d 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test") package(licenses = ["notice"]) diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go index 0a3d92854..be328db12 100644 --- a/pkg/seccomp/seccomp_unsafe.go +++ b/pkg/seccomp/seccomp_unsafe.go @@ -35,7 +35,7 @@ type sockFprog struct { //go:nosplit func SetFilter(instrs []linux.BPFInstruction) syscall.Errno { // PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details. - if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0, 0); errno != 0 { return errno } diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD index f38fb39f3..22abdc69f 100644 --- a/pkg/secio/BUILD +++ b/pkg/secio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index 694486296..12d7c77d2 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:private"], diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD index 53989301f..2d6379c86 100644 --- a/pkg/sentry/BUILD +++ b/pkg/sentry/BUILD @@ -8,5 +8,7 @@ package_group( packages = [ "//pkg/sentry/...", "//runsc/...", + # Code generated by go_marshal relies on go_marshal libraries. + "//tools/go_marshal/...", ], ) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 7aace2d7b..c71cff9f3 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -1,4 +1,5 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -42,6 +43,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "registers_cc_proto", + visibility = ["//visibility:public"], + deps = [":registers_proto"], +) + go_proto_library( name = "registers_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto", diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index bf802d1b6..5522cecd0 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 7e8918722..0c86197f7 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "device", diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index d7259b47b..3119a61b6 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "fs", diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index fbca06761..3cb73bd78 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -1126,7 +1126,7 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error { // Remove removes the given file or symlink. The root dirent is used to // resolve name, and must not be nil. -func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error { +func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error { // Check the root. if root == nil { panic("Dirent.Remove: root must not be nil") @@ -1151,6 +1151,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error { // Remove cannot remove directories. if IsDir(child.Inode.StableAttr) { return syscall.EISDIR + } else if dirPath { + return syscall.ENOTDIR } // Remove cannot remove a mount point. diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go index 884e3ff06..47bc72a88 100644 --- a/pkg/sentry/fs/dirent_refs_test.go +++ b/pkg/sentry/fs/dirent_refs_test.go @@ -343,7 +343,7 @@ func TestRemoveExtraRefs(t *testing.T) { } d := f.Dirent - if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil { + if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil { t.Fatalf("root.Remove(root, %q) failed: %v", name, err) } diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index bf00b9c09..b9bd9ed17 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "fdpipe", diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index bb8117f89..c0a6e884b 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -515,6 +515,11 @@ type lockedReader struct { // File is the file to read from. File *File + + // Offset is the offset to start at. + // + // This applies only to Read, not ReadAt. + Offset int64 } // Read implements io.Reader.Read. @@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) { if r.Ctx.Interrupted() { return 0, syserror.ErrInterrupted } - n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset) + n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset) + r.Offset += n return int(n), err } @@ -544,11 +550,21 @@ type lockedWriter struct { // File is the file to write to. File *File + + // Offset is the offset to start at. + // + // This applies only to Write, not WriteAt. + Offset int64 } // Write implements io.Writer.Write. func (w *lockedWriter) Write(buf []byte) (int, error) { - return w.WriteAt(buf, w.File.offset) + if w.Ctx.Interrupted() { + return 0, syserror.ErrInterrupted + } + n, err := w.WriteAt(buf, w.Offset) + w.Offset += int64(n) + return int(n), err } // WriteAt implements io.Writer.WriteAt. @@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { // io.Copy, since our own Write interface does not have this same // contract. Enforce that here. for written < len(buf) { + if w.Ctx.Interrupted() { + return written, syserror.ErrInterrupted + } var n int64 n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written)) if n > 0 { diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index d86f5bf45..b88303f17 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -15,6 +15,8 @@ package fs import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -105,8 +107,11 @@ type FileOperations interface { // on the destination, following by a buffered copy with standard Read // and Write operations. // + // If dup is set, the data should be duplicated into the destination + // and retained. + // // The same preconditions as Read apply. - WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error) + WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error) // Write writes src to file at offset and returns the number of bytes // written which must be greater than or equal to 0. Like Read, file @@ -126,7 +131,7 @@ type FileOperations interface { // source. See WriteTo for details regarding how this is called. // // The same preconditions as Write apply; FileFlags.Write must be set. - ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error) + ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error) // Fsync writes buffered modifications of file and/or flushes in-flight // operations to backing storage based on syncType. The range to sync is diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 9820f0b13..225e40186 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "gvisor.dev/gvisor/pkg/refs" @@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme } // WriteTo implements FileOperations.WriteTo. -func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) { err = f.onTop(ctx, file, func(file *File, ops FileOperations) error { - n, err = ops.WriteTo(ctx, file, dst, opts) + n, err = ops.WriteTo(ctx, file, dst, count, dup) return err // Will overwrite itself. }) return @@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm } // ReadFrom implements FileOperations.ReadFrom. -func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) { // See above; f.upper must be non-nil. - return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts) + return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count) } // Fsync implements FileOperations.Fsync. diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 6499f87ac..b4ac83dc4 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "dirty_set_impl", diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 626b9126a..fc5b3b1a1 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -15,6 +15,8 @@ package fsutil import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu type FileNoSplice struct{} // WriteTo implements fs.FileOperations.WriteTo. -func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } // ReadFrom implements fs.FileOperations.ReadFrom. -func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index d2495cb83..693625ddc 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -144,7 +144,7 @@ func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error { mask := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: newSize} - if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil { + if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil { return err } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index e70bc28fb..dd80757dc 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -66,10 +66,8 @@ type CachingInodeOperations struct { // mfp is used to allocate memory that caches backingFile's contents. mfp pgalloc.MemoryFileProvider - // forcePageCache indicates the sentry page cache should be used regardless - // of whether the platform supports host mapped I/O or not. This must not be - // modified after inode creation. - forcePageCache bool + // opts contains options. opts is immutable. + opts CachingInodeOperationsOptions attrMu sync.Mutex `state:"nosave"` @@ -116,6 +114,20 @@ type CachingInodeOperations struct { refs frameRefSet } +// CachingInodeOperationsOptions configures a CachingInodeOperations. +// +// +stateify savable +type CachingInodeOperationsOptions struct { + // If ForcePageCache is true, use the sentry page cache even if a host file + // descriptor is available. + ForcePageCache bool + + // If LimitHostFDTranslation is true, apply maxFillRange() constraints to + // host file descriptor mappings returned by + // CachingInodeOperations.Translate(). + LimitHostFDTranslation bool +} + // CachedFileObject is a file that may require caching. type CachedFileObject interface { // ReadToBlocksAt reads up to dsts.NumBytes() bytes from the file to dsts, @@ -128,12 +140,16 @@ type CachedFileObject interface { // WriteFromBlocksAt may return a partial write without an error. WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) - // SetMaskedAttributes sets the attributes in attr that are true in mask - // on the backing file. + // SetMaskedAttributes sets the attributes in attr that are true in + // mask on the backing file. If the mask contains only ATime or MTime + // and the CachedFileObject has an FD to the file, then this operation + // is a noop unless forceSetTimestamps is true. This avoids an extra + // RPC to the gofer in the open-read/write-close case, when the + // timestamps on the file will be updated by the host kernel for us. // // SetMaskedAttributes may be called at any point, regardless of whether // the file was opened. - SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error + SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error // Allocate allows the caller to reserve disk space for the inode. // It's equivalent to fallocate(2) with 'mode=0'. @@ -159,7 +175,7 @@ type CachedFileObject interface { // NewCachingInodeOperations returns a new CachingInodeOperations backed by // a CachedFileObject and its initial unstable attributes. -func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, forcePageCache bool) *CachingInodeOperations { +func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, opts CachingInodeOperationsOptions) *CachingInodeOperations { mfp := pgalloc.MemoryFileProviderFromContext(ctx) if mfp == nil { panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider)) @@ -167,7 +183,7 @@ func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject return &CachingInodeOperations{ backingFile: backingFile, mfp: mfp, - forcePageCache: forcePageCache, + opts: opts, attr: uattr, hostFileMapper: NewHostFileMapper(), } @@ -212,7 +228,7 @@ func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.I now := ktime.NowFromContext(ctx) masked := fs.AttrMask{Perms: true} - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil { return false } c.attr.Perms = perms @@ -234,7 +250,7 @@ func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode, UID: owner.UID.Ok(), GID: owner.GID.Ok(), } - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil { return err } if owner.UID.Ok() { @@ -270,7 +286,9 @@ func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.In AccessTime: !ts.ATimeOmit, ModificationTime: !ts.MTimeOmit, } - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil { + // Call SetMaskedAttributes with forceSetTimestamps = true to make sure + // the timestamp is updated. + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil { return err } if !ts.ATimeOmit { @@ -293,7 +311,7 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode, now := ktime.NowFromContext(ctx) masked := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: size} - if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil { c.dataMu.Unlock() return err } @@ -382,7 +400,7 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) c.dirtyAttr.Size = false // Write out cached attributes. - if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil { c.attrMu.Unlock() return err } @@ -763,7 +781,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error // and memory mappings, and false if c.cache may contain data cached from // c.backingFile. func (c *CachingInodeOperations) useHostPageCache() bool { - return !c.forcePageCache && c.backingFile.FD() >= 0 + return !c.opts.ForcePageCache && c.backingFile.FD() >= 0 } // AddMapping implements memmap.Mappable.AddMapping. @@ -784,11 +802,6 @@ func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.Mappi mf.MarkUnevictable(c, pgalloc.EvictableRange{r.Start, r.End}) } } - if c.useHostPageCache() && !usage.IncrementalMappedAccounting { - for _, r := range mapped { - usage.MemoryAccounting.Inc(r.Length(), usage.Mapped) - } - } c.mapsMu.Unlock() return nil } @@ -802,11 +815,6 @@ func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Ma c.hostFileMapper.DecRefOn(r) } if c.useHostPageCache() { - if !usage.IncrementalMappedAccounting { - for _, r := range unmapped { - usage.MemoryAccounting.Dec(r.Length(), usage.Mapped) - } - } c.mapsMu.Unlock() return } @@ -835,11 +843,15 @@ func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.Mapp func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { // Hot path. Avoid defer. if c.useHostPageCache() { + mr := optional + if c.opts.LimitHostFDTranslation { + mr = maxFillRange(required, optional) + } return []memmap.Translation{ { - Source: optional, + Source: mr, File: c, - Offset: optional.Start, + Offset: mr.Start, Perms: usermem.AnyAccess, }, }, nil @@ -985,9 +997,7 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { seg, gap = seg.NextNonEmpty() case gap.Ok() && gap.Start() < fr.End: newRange := gap.Range().Intersect(fr) - if usage.IncrementalMappedAccounting { - usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped) - } + usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped) seg, gap = c.refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty() default: c.refs.MergeAdjacent(fr) @@ -1008,9 +1018,7 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { for seg.Ok() && seg.Start() < fr.End { seg = c.refs.Isolate(seg, fr) if old := seg.Value(); old == 1 { - if usage.IncrementalMappedAccounting { - usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped) - } + usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped) seg = c.refs.Remove(seg).NextSegment() } else { seg.SetValue(old - 1) diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go index dc19255ed..129f314c8 100644 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ b/pkg/sentry/fs/fsutil/inode_cached_test.go @@ -39,7 +39,7 @@ func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.Block return srcs.NumBytes(), nil } -func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error { +func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { return nil } @@ -61,7 +61,7 @@ func TestSetPermissions(t *testing.T) { uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{ Perms: fs.FilePermsFromMode(0444), }) - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) defer iops.Release() perms := fs.FilePermsFromMode(0777) @@ -150,7 +150,7 @@ func TestSetTimestamps(t *testing.T) { ModificationTime: epoch, StatusChangeTime: epoch, } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) defer iops.Release() if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil { @@ -188,7 +188,7 @@ func TestTruncate(t *testing.T) { uattr := fs.UnstableAttr{ Size: 0, } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) defer iops.Release() if err := iops.Truncate(ctx, nil, uattr.Size); err != nil { @@ -230,7 +230,7 @@ func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.B return w.WriteFromBlocks(srcs) } -func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error { +func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { return nil } @@ -280,7 +280,7 @@ func TestRead(t *testing.T) { uattr := fs.UnstableAttr{ Size: int64(len(buf)), } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{}) defer iops.Release() // Expect the cache to be initially empty. @@ -336,7 +336,7 @@ func TestWrite(t *testing.T) { uattr := fs.UnstableAttr{ Size: int64(len(buf)), } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{}) defer iops.Release() // Expect the cache to be initially empty. diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 6b993928c..2b71ca0e1 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "gofer", diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go index 69999dc28..8f8ab5d29 100644 --- a/pkg/sentry/fs/gofer/fs.go +++ b/pkg/sentry/fs/gofer/fs.go @@ -54,6 +54,10 @@ const ( // sandbox using files backed by the gofer. If set to false, unix sockets // cannot be bound to gofer files without an overlay on top. privateUnixSocketKey = "privateunixsocket" + + // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to + // true. + limitHostFDTranslationKey = "limit_host_fd_translation" ) // defaultAname is the default attach name. @@ -134,12 +138,13 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou // opts are parsed 9p mount options. type opts struct { - fd int - aname string - policy cachePolicy - msize uint32 - version string - privateunixsocket bool + fd int + aname string + policy cachePolicy + msize uint32 + version string + privateunixsocket bool + limitHostFDTranslation bool } // options parses mount(2) data into structured options. @@ -237,6 +242,11 @@ func options(data string) (opts, error) { delete(options, privateUnixSocketKey) } + if _, ok := options[limitHostFDTranslationKey]; ok { + o.limitHostFDTranslation = true + delete(options, limitHostFDTranslationKey) + } + // Fail to attach if the caller wanted us to do something that we // don't support. if len(options) > 0 { diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 95b064aea..d918d6620 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -215,8 +215,8 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo } // SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. -func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error { - if i.skipSetAttr(mask) { +func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error { + if i.skipSetAttr(mask, forceSetTimestamps) { return nil } as, ans := attr.AccessTime.Unix() @@ -251,13 +251,14 @@ func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMa // when: // - Mask is empty // - Mask contains only attributes that cannot be set in the gofer -// - Mask contains only atime and/or mtime, and host FD exists +// - forceSetTimestamps is false and mask contains only atime and/or mtime +// and host FD exists // // Updates to atime and mtime can be skipped because cached value will be // "close enough" to host value, given that operation went directly to host FD. // Skipping atime updates is particularly important to reduce the number of // operations sent to the Gofer for readonly files. -func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool { +func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool { // First remove attributes that cannot be updated. cpy := mask cpy.Type = false @@ -277,6 +278,12 @@ func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool { return false } + // If forceSetTimestamps was passed, then we cannot skip. + if forceSetTimestamps { + return false + } + + // Skip if we have a host FD. i.handlesMu.RLock() defer i.handlesMu.RUnlock() return (i.readHandles != nil && i.readHandles.Host != nil) || diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 69d08a627..50da865c1 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -117,6 +117,11 @@ type session struct { // Flags provided to the mount. superBlockFlags fs.MountSourceFlags `state:"wait"` + // limitHostFDTranslation is the value used for + // CachingInodeOperationsOptions.LimitHostFDTranslation for all + // CachingInodeOperations created by the session. + limitHostFDTranslation bool + // connID is a unique identifier for the session connection. connID string `state:"wait"` @@ -218,8 +223,11 @@ func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p uattr := unstable(ctx, valid, attr, s.mounter, s.client) return sattr, &inodeOperations{ - fileState: fileState, - cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, s.superBlockFlags.ForcePageCache), + fileState: fileState, + cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{ + ForcePageCache: s.superBlockFlags.ForcePageCache, + LimitHostFDTranslation: s.limitHostFDTranslation, + }), } } @@ -242,13 +250,14 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF // Construct the session. s := session{ - connID: dev, - msize: o.msize, - version: o.version, - cachePolicy: o.policy, - aname: o.aname, - superBlockFlags: superBlockFlags, - mounter: mounter, + connID: dev, + msize: o.msize, + version: o.version, + cachePolicy: o.policy, + aname: o.aname, + superBlockFlags: superBlockFlags, + limitHostFDTranslation: o.limitHostFDTranslation, + mounter: mounter, } s.EnableLeakCheck("gofer.session") diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index b1080fb1a..3e532332e 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "host", diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 679d8321a..a6e4a09e3 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -114,7 +114,7 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo } // SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. -func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error { +func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error { if mask.Empty() { return nil } @@ -163,7 +163,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err return unstableAttr(i.mops, &s), nil } -// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. +// Allocate implements fsutil.CachedFileObject.Allocate. func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error { return syscall.Fallocate(i.FD(), 0, offset, length) } @@ -200,8 +200,10 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool, // Build the fs.InodeOperations. uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s) iops := &inodeOperations{ - fileState: fileState, - cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, msrc.Flags.ForcePageCache), + fileState: fileState, + cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{ + ForcePageCache: msrc.Flags.ForcePageCache, + }), } // Return the fs.Inode. diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 2526412a4..90331e3b2 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -43,12 +43,15 @@ type TTYFileOperations struct { // fgProcessGroup is the foreground process group that is currently // connected to this TTY. fgProcessGroup *kernel.ProcessGroup + + termios linux.KernelTermios } // newTTYFile returns a new fs.File that wraps a TTY FD. func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File { return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{ fileOperations: fileOperations{iops: iops}, + termios: linux.DefaultSlaveTermios, }) } @@ -97,9 +100,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme t.mu.Lock() defer t.mu.Unlock() - // Are we allowed to do the write? - if err := t.checkChange(ctx, linux.SIGTTOU); err != nil { - return 0, err + // Check whether TOSTOP is enabled. This corresponds to the check in + // drivers/tty/n_tty.c:n_tty_write(). + if t.termios.LEnabled(linux.TOSTOP) { + if err := t.checkChange(ctx, linux.SIGTTOU); err != nil { + return 0, err + } } return t.fileOperations.Write(ctx, file, src, offset) } @@ -144,6 +150,9 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO return 0, err } err := ioctlSetTermios(fd, ioctl, &termios) + if err == nil { + t.termios.FromTermios(termios) + } return 0, err case linux.TIOCGPGRP: diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 246b97161..5a388dad1 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "fmt" "strings" "gvisor.dev/gvisor/pkg/abi/linux" @@ -207,6 +208,11 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name } func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) { + // Sanity check. + if parent.Inode.overlay == nil { + panic(fmt.Sprintf("overlayCreate called with non-overlay parent inode (parent InodeOperations type is %T)", parent.Inode.InodeOperations)) + } + // Dirent.Create takes renameMu if the Inode is an overlay Inode. if err := copyUpLockedForRename(ctx, parent); err != nil { return nil, err diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index c7f4e2d13..ba3e0233d 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "sync/atomic" @@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i } // WriteTo implements FileOperations.WriteTo. -func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } @@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error { } // ReadFrom implements FileOperations.ReadFrom. -func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 08d7c0c57..5a7a5b8cd 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "lock_range", diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index 9b713e785..ac0398bd9 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -171,8 +171,6 @@ type MountNamespace struct { // NewMountNamespace returns a new MountNamespace, with the provided node at the // root, and the given cache size. A root must always be provided. func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error) { - creds := auth.CredentialsFromContext(ctx) - // Set the root dirent and id on the root mount. The reference returned from // NewDirent will be donated to the MountNamespace constructed below. d := NewDirent(ctx, root, "/") @@ -181,6 +179,7 @@ func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error d: newRootMount(1, d), } + creds := auth.CredentialsFromContext(ctx) mns := MountNamespace{ userns: creds.UserNamespace, root: d, diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 70ed854a8..1c93e8886 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "proc", @@ -31,7 +33,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/log", "//pkg/sentry/context", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 9adb23608..f70239449 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -17,10 +17,10 @@ package proc import ( "bytes" "fmt" + "io" "time" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -28,9 +28,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) // newNet creates a new proc net entry. @@ -57,15 +59,14 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")), "route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")), "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), - "udp": newStaticProcInode(ctx, msrc, []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops")), - - "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), + "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), + "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), } if s.SupportsIPv6() { contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc) contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte("")) - contents["tcp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) + contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc) contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) } } @@ -216,7 +217,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s for _, se := range n.k.ListSockets() { s := se.Sock.Get() if s == nil { - log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock) + log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) continue } sfile := s.(*fs.File) @@ -297,20 +298,66 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s return data, 0 } -// netTCP implements seqfile.SeqSource for /proc/net/tcp. -// -// +stateify savable -type netTCP struct { - k *kernel.Kernel +func networkToHost16(n uint16) uint16 { + // n is in network byte order, so is big-endian. The most-significant byte + // should be stored in the lower address. + // + // We manually inline binary.BigEndian.Uint16() because Go does not support + // non-primitive consts, so binary.BigEndian is a (mutable) var, so calls to + // binary.BigEndian.Uint16() require a read of binary.BigEndian and an + // interface method call, defeating inlining. + buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)} + return usermem.ByteOrder.Uint16(buf[:]) } -// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. -func (*netTCP) NeedsUpdate(generation int64) bool { - return true +func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { + switch family { + case linux.AF_INET: + var a linux.SockAddrInet + if i != nil { + a = *i.(*linux.SockAddrInet) + } + + // linux.SockAddrInet.Port is stored in the network byte order and is + // printed like a number in host byte order. Note that all numbers in host + // byte order are printed with the most-significant byte first when + // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. + port := networkToHost16(a.Port) + + // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order + // (i.e. most-significant byte in index 0). Linux represents this as a + // __be32 which is a typedef for an unsigned int, and is printed with + // %X. This means that for a little-endian machine, Linux prints the + // least-significant byte of the address first. To emulate this, we first + // invert the byte order for the address using usermem.ByteOrder.Uint32, + // which makes it have the equivalent encoding to a __be32 on a little + // endian machine. Note that this operation is a no-op on a big endian + // machine. Then similar to Linux, we format it with %X, which will print + // the most-significant byte of the __be32 address first, which is now + // actually the least-significant byte of the original address in + // linux.SockAddrInet.Addr on little endian machines, due to the conversion. + addr := usermem.ByteOrder.Uint32(a.Addr[:]) + + fmt.Fprintf(w, "%08X:%04X ", addr, port) + case linux.AF_INET6: + var a linux.SockAddrInet6 + if i != nil { + a = *i.(*linux.SockAddrInet6) + } + + port := networkToHost16(a.Port) + addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4]) + addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8]) + addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12]) + addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16]) + fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port) + } } -// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. -func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { +func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kernel.Kernel, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) { + // t may be nil here if our caller is not part of a task goroutine. This can + // happen for example if we're here for "sentryctl cat". When t is nil, + // degrade gracefully and retrieve what we can. t := kernel.TaskFromContext(ctx) if h != nil { @@ -318,10 +365,10 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se } var buf bytes.Buffer - for _, se := range n.k.ListSockets() { + for _, se := range k.ListSockets() { s := se.Sock.Get() if s == nil { - log.Debugf("Couldn't resolve weakref %+v in socket table, racing with destruction?", se.Sock) + log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) continue } sfile := s.(*fs.File) @@ -329,7 +376,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se if !ok { panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } - if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) { + if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { s.DecRef() // Not tcp4 sockets. continue @@ -343,27 +390,23 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se // Field: sl; entry number. fmt.Fprintf(&buf, "%4d: ", se.ID) - portBuf := make([]byte, 2) - // Field: local_adddress. - var localAddr linux.SockAddrInet - if local, _, err := sops.GetSockName(t); err == nil { - localAddr = *local.(*linux.SockAddrInet) + var localAddr linux.SockAddr + if t != nil { + if local, _, err := sops.GetSockName(t); err == nil { + localAddr = local + } } - binary.LittleEndian.PutUint16(portBuf, localAddr.Port) - fmt.Fprintf(&buf, "%08X:%04X ", - binary.LittleEndian.Uint32(localAddr.Addr[:]), - portBuf) + writeInetAddr(&buf, fa, localAddr) // Field: rem_address. - var remoteAddr linux.SockAddrInet - if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = *remote.(*linux.SockAddrInet) + var remoteAddr linux.SockAddr + if t != nil { + if remote, _, err := sops.GetPeerName(t); err == nil { + remoteAddr = remote + } } - binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port) - fmt.Fprintf(&buf, "%08X:%04X ", - binary.LittleEndian.Uint32(remoteAddr.Addr[:]), - portBuf) + writeInetAddr(&buf, fa, remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) @@ -386,7 +429,8 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se log.Warningf("Failed to retrieve unstable attr for socket file: %v", err) fmt.Fprintf(&buf, "%5d ", 0) } else { - fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow())) + creds := auth.CredentialsFromContext(ctx) + fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow())) } // Field: timeout; number of unanswered 0-window probes. @@ -428,7 +472,165 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se data := []seqfile.SeqData{ { - Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n"), + Buf: header, + Handle: n, + }, + { + Buf: buf.Bytes(), + Handle: n, + }, + } + return data, 0 +} + +// netTCP implements seqfile.SeqSource for /proc/net/tcp. +// +// +stateify savable +type netTCP struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET, header) +} + +// netTCP6 implements seqfile.SeqSource for /proc/net/tcp6. +// +// +stateify savable +type netTCP6 struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP6) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET6, header) +} + +// netUDP implements seqfile.SeqSource for /proc/net/udp. +// +// +stateify savable +type netUDP struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netUDP) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + // t may be nil here if our caller is not part of a task goroutine. This can + // happen for example if we're here for "sentryctl cat". When t is nil, + // degrade gracefully and retrieve what we can. + t := kernel.TaskFromContext(ctx) + + if h != nil { + return nil, 0 + } + + var buf bytes.Buffer + for _, se := range n.k.ListSockets() { + s := se.Sock.Get() + if s == nil { + log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) + continue + } + sfile := s.(*fs.File) + sops, ok := sfile.FileOperations.(socket.Socket) + if !ok { + panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) + } + if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM { + s.DecRef() + // Not udp4 socket. + continue + } + + // For Linux's implementation, see net/ipv4/udp.c:udp4_format_sock(). + + // Field: sl; entry number. + fmt.Fprintf(&buf, "%5d: ", se.ID) + + // Field: local_adddress. + var localAddr linux.SockAddrInet + if t != nil { + if local, _, err := sops.GetSockName(t); err == nil { + localAddr = *local.(*linux.SockAddrInet) + } + } + writeInetAddr(&buf, linux.AF_INET, &localAddr) + + // Field: rem_address. + var remoteAddr linux.SockAddrInet + if t != nil { + if remote, _, err := sops.GetPeerName(t); err == nil { + remoteAddr = *remote.(*linux.SockAddrInet) + } + } + writeInetAddr(&buf, linux.AF_INET, &remoteAddr) + + // Field: state; socket state. + fmt.Fprintf(&buf, "%02X ", sops.State()) + + // Field: tx_queue, rx_queue; number of packets in the transmit and + // receive queue. Unimplemented. + fmt.Fprintf(&buf, "%08X:%08X ", 0, 0) + + // Field: tr, tm->when. Always 0 for UDP. + fmt.Fprintf(&buf, "%02X:%08X ", 0, 0) + + // Field: retrnsmt. Always 0 for UDP. + fmt.Fprintf(&buf, "%08X ", 0) + + // Field: uid. + uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx) + if err != nil { + log.Warningf("Failed to retrieve unstable attr for socket file: %v", err) + fmt.Fprintf(&buf, "%5d ", 0) + } else { + creds := auth.CredentialsFromContext(ctx) + fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow())) + } + + // Field: timeout. Always 0 for UDP. + fmt.Fprintf(&buf, "%8d ", 0) + + // Field: inode. + fmt.Fprintf(&buf, "%8d ", sfile.InodeID()) + + // Field: ref; reference count on the socket inode. Don't count the ref + // we obtain while deferencing the weakref to this socket. + fmt.Fprintf(&buf, "%d ", sfile.ReadRefs()-1) + + // Field: Socket struct address. Redacted due to the same reason as + // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. + fmt.Fprintf(&buf, "%#016p ", (*socket.Socket)(nil)) + + // Field: drops; number of dropped packets. Unimplemented. + fmt.Fprintf(&buf, "%d", 0) + + fmt.Fprintf(&buf, "\n") + + s.DecRef() + } + + data := []seqfile.SeqData{ + { + Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops \n"), Handle: n, }, { diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 0ef13f2f5..56e92721e 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -230,7 +230,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent // But for whatever crazy reason, you can still walk to the given node. for _, tg := range rpf.iops.pidns.ThreadGroups() { if leader := tg.Leader(); leader != nil { - name := strconv.FormatUint(uint64(tg.ID()), 10) + name := strconv.FormatUint(uint64(rpf.iops.pidns.IDOfThreadGroup(tg)), 10) m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice) names = append(names, name) } diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 20c3eefc8..76433c7d0 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "seqfile", diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 516efcc4c..d0f351e5a 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "ramfs", diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index eed1c2854..311798811 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -18,7 +18,6 @@ import ( "io" "sync/atomic" - "gvisor.dev/gvisor/pkg/secio" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserror" ) @@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } // Check whether or not the objects being sliced are stream-oriented - // (i.e. pipes or sockets). If yes, we elide checks and offset locks. - srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr) - dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr) + // (i.e. pipes or sockets). For all stream-oriented files and files + // where a specific offiset is not request, we acquire the file mutex. + // This has two important side effects. First, it provides the standard + // protection against concurrent writes that would mutate the offset. + // Second, it prevents Splice deadlocks. Only internal anonymous files + // implement the ReadFrom and WriteTo methods directly, and since such + // anonymous files are referred to by a unique fs.File object, we know + // that the file mutex takes strict precedence over internal locks. + // Since we enforce lock ordering here, we can't deadlock by using + // using a file in two different splice operations simultaneously. + srcPipe := !IsRegular(src.Dirent.Inode.StableAttr) + dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr) + dstAppend := !dstPipe && dst.Flags().Append + srcLock := srcPipe || !opts.SrcOffset + dstLock := dstPipe || !opts.DstOffset || dstAppend - if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset { + switch { + case srcLock && dstLock: switch { case dst.UniqueID < src.UniqueID: // Acquire dst first. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() if !src.mu.Lock(ctx) { + dst.mu.Unlock() return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() case dst.UniqueID > src.UniqueID: // Acquire src first. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() if !dst.mu.Lock(ctx) { + src.mu.Unlock() return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() case dst.UniqueID == src.UniqueID: // Acquire only one lock; it's the same file. This is a // bit of a edge case, but presumably it's possible. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() + srcLock = false // Only need one unlock. } // Use both offsets (locked). opts.DstStart = dst.offset opts.SrcStart = src.offset - } else if !dstPipe && !opts.DstOffset { + case dstLock: // Acquire only dst. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() opts.DstStart = dst.offset // Safe: locked. - } else if !srcPipe && !opts.SrcOffset { + case srcLock: // Acquire only src. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() opts.SrcStart = src.offset // Safe: locked. } - // Check append-only mode and the limit. - if !dstPipe { + var err error + if dstAppend { unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append) defer unlock() - if dst.Flags().Append { - if opts.DstOffset { - // We need to acquire the lock. - if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted - } - defer dst.mu.Unlock() - } - // Figure out the appropriate offset to use. - if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil { - return 0, err - } - } + // Figure out the appropriate offset to use. + err = dst.offsetForAppend(ctx, &opts.DstStart) + } + if err == nil && !dstPipe { // Enforce file limits. limit, ok := dst.checkLimit(ctx, opts.DstStart) switch { case ok && limit == 0: - return 0, syserror.ErrExceedsFileSizeLimit + err = syserror.ErrExceedsFileSizeLimit case ok && limit < opts.Length: opts.Length = limit // Cap the write. } } + if err != nil { + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return 0, err + } - // Attempt to do a WriteTo; this is likely the most efficient. - // - // The underlying implementation may be able to donate buffers. - newOpts := SpliceOpts{ - Length: opts.Length, - SrcStart: opts.SrcStart, - SrcOffset: !srcPipe, - Dup: opts.Dup, - DstStart: opts.DstStart, - DstOffset: !dstPipe, + // Construct readers and writers for the splice. This is used to + // provide a safer locking path for the WriteTo/ReadFrom operations + // (since they will otherwise go through public interface methods which + // conflict with locking done above), and simplifies the fallback path. + w := &lockedWriter{ + Ctx: ctx, + File: dst, + Offset: opts.DstStart, } - n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts) - if n == 0 && err != nil { - // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also - // be more efficient than a copy if buffers are cached or readily - // available. (It's unlikely that they can actually be donate - n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts) + r := &lockedReader{ + Ctx: ctx, + File: src, + Offset: opts.SrcStart, } - if n == 0 && err != nil { - // If we've failed up to here, and at least one of the sources - // is a pipe or socket, then we can't properly support dup. - // Return an error indicating that this operation is not - // supported. - if (srcPipe || dstPipe) && newOpts.Dup { - return 0, syserror.EINVAL - } - // We failed to splice the files. But that's fine; we just fall - // back to a slow path in this case. This copies without doing - // any mode changes, so should still be more efficient. - var ( - r io.Reader - w io.Writer - ) - fw := &lockedWriter{ - Ctx: ctx, - File: dst, - } - if newOpts.DstOffset { - // Use the provided offset. - w = secio.NewOffsetWriter(fw, newOpts.DstStart) - } else { - // Writes will proceed with no offset. - w = fw - } - fr := &lockedReader{ - Ctx: ctx, - File: src, - } - if newOpts.SrcOffset { - // Limit to the given offset and length. - r = io.NewSectionReader(fr, opts.SrcStart, opts.Length) - } else { - // Limit just to the given length. - r = &io.LimitedReader{fr, opts.Length} - } + // Attempt to do a WriteTo; this is likely the most efficient. + n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) + if n == 0 && err == syserror.ENOSYS && !opts.Dup { + // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be + // more efficient than a copy if buffers are cached or readily + // available. (It's unlikely that they can actually be donated). + n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length) + } - // Copy between the two. - n, err = io.Copy(w, r) + // Support one last fallback option, but only if at least one of + // the source and destination are regular files. This is because + // if we block at some point, we could lose data. If the source is + // not a pipe then reading is not destructive; if the destination + // is a regular file, then it is guaranteed not to block writing. + if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) { + // Fallback to an in-kernel copy. + n, err = io.Copy(w, &io.LimitedReader{ + R: r, + N: opts.Length, + }) } // Update offsets, if required. @@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } } + // Drop locks. + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return n, err } diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index 59403d9db..f8bf663bb 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -141,9 +141,10 @@ func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, i } // Notify implements ktime.TimerListener.Notify. -func (t *TimerOperations) Notify(exp uint64) { +func (t *TimerOperations) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { atomic.AddUint64(&t.val, exp) t.events.Notify(waiter.EventIn) + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 8f7eb5757..11b680929 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tmpfs", diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 159fb7c08..69089c8a8 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -324,7 +324,7 @@ type Fifo struct { // NewFifo creates a new named pipe. func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { // First create a pipe. - p := pipe.NewPipe(ctx, true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) // Build pipe InodeOperations. iops := pipe.NewInodeOperations(ctx, perms, p) diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 291164986..25811f668 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tty", diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index a41101339..b0c286b7a 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") go_template_instance( @@ -79,7 +81,7 @@ go_test( "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", - "//runsc/test/testutil", + "//runsc/testutil", "@com_github_google_go-cmp//cmp:go_default_library", "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 9fddb4c4c..bfc46dfa6 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index b51f3e18d..91802dc1e 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba } if !cb.Handle(vfs.Dirent{ - Name: child.diskDirent.FileName(), - Type: fs.ToDirentType(childType), - Ino: uint64(child.diskDirent.Inode()), - Off: fd.off, + Name: child.diskDirent.FileName(), + Type: fs.ToDirentType(childType), + Ino: uint64(child.diskDirent.Inode()), + NextOff: fd.off + 1, }) { dir.childList.InsertBefore(child, fd.iter) return nil @@ -301,8 +301,8 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in return offset, nil } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { // mmap(2) specifies that EACCESS should be returned for non-regular file fds. return syserror.EACCES } diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 907d35b7e..2d50e30aa 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "disklayout", diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 49b57a2d6..1aa2bd6a4 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -33,7 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/runsc/test/testutil" + "gvisor.dev/gvisor/runsc/testutil" ) const ( @@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) { // Ignore the inode number and offset of dirents because those are likely to // change as the underlying image changes. cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool { - return p.String() == "Ino" || p.String() == "Off" + return p.String() == "Ino" || p.String() == "NextOff" }, cmp.Ignore()) if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" { t.Errorf("dirents mismatch (-want +got):\n%s", diff) diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go index a0065343b..4d18b28cb 100644 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -43,9 +43,6 @@ func (fd *fileDescription) inode() *inode { return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode } -// OnClose implements vfs.FileDescriptionImpl.OnClose. -func (fd *fileDescription) OnClose() error { return nil } - // StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { return fd.flags, nil diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index ffc76ba5b..aec33e00a 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -152,8 +152,8 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( return offset, nil } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { // TODO(b/134676337): Implement mmap(2). return syserror.ENODEV } diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index e06548a98..bdf8705c1 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -105,7 +105,7 @@ func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int6 return 0, syserror.EBADF } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return syserror.EBADF } diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD index d2450e810..7e364c5fd 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/memfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go index c52dc781c..0bd82e480 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/memfs/directory.go @@ -32,7 +32,7 @@ type directory struct { childList dentryList } -func (fs *filesystem) newDirectory(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode { dir := &directory{} dir.inode.init(dir, fs, creds, mode) dir.inode.nlink = 2 // from "." and parent directory or ".." for root @@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if fd.off == 0 { if !cb.Handle(vfs.Dirent{ - Name: ".", - Type: linux.DT_DIR, - Ino: vfsd.Impl().(*dentry).inode.ino, - Off: 0, + Name: ".", + Type: linux.DT_DIR, + Ino: vfsd.Impl().(*dentry).inode.ino, + NextOff: 1, }) { return nil } @@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if fd.off == 1 { parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode if !cb.Handle(vfs.Dirent{ - Name: "..", - Type: parentInode.direntType(), - Ino: parentInode.ino, - Off: 1, + Name: "..", + Type: parentInode.direntType(), + Ino: parentInode.ino, + NextOff: 2, }) { return nil } @@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba // Skip other directoryFD iterators. if child.inode != nil { if !cb.Handle(vfs.Dirent{ - Name: child.vfsd.Name(), - Type: child.inode.direntType(), - Ino: child.inode.ino, - Off: fd.off, + Name: child.vfsd.Name(), + Type: child.inode.direntType(), + Ino: child.inode.ino, + NextOff: fd.off + 1, }) { dir.childList.InsertBefore(child, fd.iter) return nil diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go index 45cd42b3e..b78471c0f 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/memfs/memfs.go @@ -137,7 +137,7 @@ type inode struct { impl interface{} // immutable } -func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode uint16) { +func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) { i.refs = 1 i.mode = uint32(mode) i.uid = uint32(creds.EffectiveKUID) diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go index 55f869798..b7f4853b3 100644 --- a/pkg/sentry/fsimpl/memfs/regular_file.go +++ b/pkg/sentry/fsimpl/memfs/regular_file.go @@ -37,7 +37,7 @@ type regularFile struct { dataLen int64 } -func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { file := ®ularFile{} file.inode.init(file, fs, creds, mode) file.inode.nlink = 1 // from parent directory diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 3d8a4deaf..ade6ac946 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD index f989f2f8b..359468ccc 100644 --- a/pkg/sentry/hostcpu/BUILD +++ b/pkg/sentry/hostcpu/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,6 +7,7 @@ go_library( name = "hostcpu", srcs = [ "getcpu_amd64.s", + "getcpu_arm64.s", "hostcpu.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu", diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s new file mode 100644 index 000000000..caf9abb89 --- /dev/null +++ b/pkg/sentry/hostcpu/getcpu_arm64.s @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall for +// the lack of an optimazed way of getting the current CPU number on arm64. + +// func GetCPU() (cpu uint32) +TEXT ·GetCPU(SB), NOSPLIT, $0-4 + MOVW ZR, cpu+0(FP) + MOVD $cpu+0(FP), R0 + MOVD $0x0, R1 // unused + MOVD $0x0, R2 // unused + MOVD $0xA8, R8 // SYS_GETCPU + SVC + RET diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 41bee9a22..aba2414d4 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,9 +1,11 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "pending_signals_list", @@ -83,6 +85,12 @@ proto_library( deps = ["//pkg/sentry/arch:registers_proto"], ) +cc_proto_library( + name = "uncaught_signal_cc_proto", + visibility = ["//visibility:public"], + deps = [":uncaught_signal_proto"], +) + go_proto_library( name = "uncaught_signal_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto", diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index f46c43128..65427b112 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "epoll_list", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index 1c5f979d4..983ca67ed 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "eventfd", diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 6a31dc044..41f44999c 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "atomicptr_bucket", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8c1f79ab5..3cda03891 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -24,6 +24,7 @@ // TaskSet.mu // SignalHandlers.mu // Task.mu +// runningTasksMu // // Locking SignalHandlers.mu in multiple SignalHandlers requires locking // TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same @@ -135,6 +136,22 @@ type Kernel struct { // syslog is the kernel log. syslog syslog + // runningTasksMu synchronizes disable/enable of cpuClockTicker when + // the kernel is idle (runningTasks == 0). + // + // runningTasksMu is used to exclude critical sections when the timer + // disables itself and when the first active task enables the timer, + // ensuring that tasks always see a valid cpuClock value. + runningTasksMu sync.Mutex `state:"nosave"` + + // runningTasks is the total count of tasks currently in + // TaskGoroutineRunningSys or TaskGoroutineRunningApp. i.e., they are + // not blocked or stopped. + // + // runningTasks must be accessed atomically. Increments from 0 to 1 are + // further protected by runningTasksMu (see incRunningTasks). + runningTasks int64 + // cpuClock is incremented every linux.ClockTick. cpuClock is used to // measure task CPU usage, since sampling monotonicClock twice on every // syscall turns out to be unreasonably expensive. This is similar to how @@ -150,6 +167,22 @@ type Kernel struct { // cpuClockTicker increments cpuClock. cpuClockTicker *ktime.Timer `state:"nosave"` + // cpuClockTickerDisabled indicates that cpuClockTicker has been + // disabled because no tasks are running. + // + // cpuClockTickerDisabled is protected by runningTasksMu. + cpuClockTickerDisabled bool + + // cpuClockTickerSetting is the ktime.Setting of cpuClockTicker at the + // point it was disabled. It is cached here to avoid a lock ordering + // violation with cpuClockTicker.mu when runningTaskMu is held. + // + // cpuClockTickerSetting is only valid when cpuClockTickerDisabled is + // true. + // + // cpuClockTickerSetting is protected by runningTasksMu. + cpuClockTickerSetting ktime.Setting + // fdMapUids is an ever-increasing counter for generating FDTable uids. // // fdMapUids is mutable, and is accessed using atomic memory operations. @@ -912,6 +945,102 @@ func (k *Kernel) resumeTimeLocked() { } } +func (k *Kernel) incRunningTasks() { + for { + tasks := atomic.LoadInt64(&k.runningTasks) + if tasks != 0 { + // Standard case. Simply increment. + if !atomic.CompareAndSwapInt64(&k.runningTasks, tasks, tasks+1) { + continue + } + return + } + + // Transition from 0 -> 1. Synchronize with other transitions and timer. + k.runningTasksMu.Lock() + tasks = atomic.LoadInt64(&k.runningTasks) + if tasks != 0 { + // We're no longer the first task, no need to + // re-enable. + atomic.AddInt64(&k.runningTasks, 1) + k.runningTasksMu.Unlock() + return + } + + if !k.cpuClockTickerDisabled { + // Timer was never disabled. + atomic.StoreInt64(&k.runningTasks, 1) + k.runningTasksMu.Unlock() + return + } + + // We need to update cpuClock for all of the ticks missed while we + // slept, and then re-enable the timer. + // + // The Notify in Swap isn't sufficient. kernelCPUClockTicker.Notify + // always increments cpuClock by 1 regardless of the number of + // expirations as a heuristic to avoid over-accounting in cases of CPU + // throttling. + // + // We want to cover the normal case, when all time should be accounted, + // so we increment for all expirations. Throttling is less concerning + // here because the ticker is only disabled from Notify. This means + // that Notify must schedule and compensate for the throttled period + // before the timer is disabled. Throttling while the timer is disabled + // doesn't matter, as nothing is running or reading cpuClock anyways. + // + // S/R also adds complication, as there are two cases. Recall that + // monotonicClock will jump forward on restore. + // + // 1. If the ticker is enabled during save, then on Restore Notify is + // called with many expirations, covering the time jump, but cpuClock + // is only incremented by 1. + // + // 2. If the ticker is disabled during save, then after Restore the + // first wakeup will call this function and cpuClock will be + // incremented by the number of expirations across the S/R. + // + // These cause very different value of cpuClock. But again, since + // nothing was running while the ticker was disabled, those differences + // don't matter. + setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now()) + if exp > 0 { + atomic.AddUint64(&k.cpuClock, exp) + } + + // Now that cpuClock is updated it is safe to allow other tasks to + // transition to running. + atomic.StoreInt64(&k.runningTasks, 1) + + // N.B. we must unlock before calling Swap to maintain lock ordering. + // + // cpuClockTickerDisabled need not wait until after Swap to become + // true. It is sufficient that the timer *will* be enabled. + k.cpuClockTickerDisabled = false + k.runningTasksMu.Unlock() + + // This won't call Notify (unless it's been ClockTick since setting.At + // above). This means we skip the thread group work in Notify. However, + // since nothing was running while we were disabled, none of the timers + // could have expired. + k.cpuClockTicker.Swap(setting) + + return + } +} + +func (k *Kernel) decRunningTasks() { + tasks := atomic.AddInt64(&k.runningTasks, -1) + if tasks < 0 { + panic(fmt.Sprintf("Invalid running count %d", tasks)) + } + + // Nothing to do. The next CPU clock tick will disable the timer if + // there is still nothing running. This provides approximately one tick + // of slack in which we can switch back and forth between idle and + // active without an expensive transition. +} + // WaitExited blocks until all tasks in k have exited. func (k *Kernel) WaitExited() { k.tasks.liveGoroutines.Wait() diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index ebcfaa619..d7a7d1169 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -24,6 +25,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "memory_events_cc_proto", + visibility = ["//visibility:public"], + deps = [":memory_events_proto"], +) + go_proto_library( name = "memory_events_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto", diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 4d15cca85..2ce8952e2 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "buffer_list", diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 69ef2a720..95bee2d37 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" @@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } +// WriteFromReader writes to the buffer from an io.Reader. +func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) { + dst := b.data[b.write:] + if count < int64(len(dst)) { + dst = b.data[b.write:][:count] + } + n, err := r.Read(dst) + b.write += n + return int64(n), err +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write])) @@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return n, err } +// ReadToWriter reads from the buffer into an io.Writer. +func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) { + src := b.data[b.read:b.write] + if count < int64(len(src)) { + src = b.data[b.read:][:count] + } + n, err := w.Write(src) + if !dup { + b.read += n + } + return int64(n), err +} + // bufferPool is a pool for buffers. var bufferPool = sync.Pool{ New: func() interface{} { diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index adbad7764..16fa80abe 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs. } func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(contexttest.Context(t), true, DefaultPipeSize, usermem.PageSize) + return NewPipe(true, DefaultPipeSize, usermem.PageSize) } func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(contexttest.Context(t), false, DefaultPipeSize, usermem.PageSize) + return NewPipe(false, DefaultPipeSize, usermem.PageSize) } // assertRecvBlocks ensures that a recv attempt on c blocks for at least diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 247e2928e..8e4e8e82e 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -99,7 +98,7 @@ type Pipe struct { // NewPipe initializes and returns a pipe. // // N.B. The size and atomicIOBytes will be bounded. -func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { +func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } @@ -122,7 +121,7 @@ func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) // NewConnectedPipe initializes a pipe and returns a pair of objects // representing the read and write ends of the pipe. func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { - p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes) + p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes) // Build an fs.Dirent for the pipe which will be shared by both // returned files. @@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F } } +type readOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit limits subsequence reads. + limit func(int64) + + // read performs the actual read operation. + read func(*buffer) (int64, error) +} + // read reads data from the pipe into dst and returns the number of bytes // read, or returns ErrWouldBlock if the pipe is empty. // // Precondition: this pipe must have readers. -func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) { +func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { // Don't block for a zero-length read even if the pipe is empty. - if dst.NumBytes() == 0 { + if ops.left() == 0 { return 0, nil } @@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Limit how much we consume. - if dst.NumBytes() > p.size { - dst = dst.TakeFirst64(p.size) + if ops.left() > p.size { + ops.limit(p.size) } done := int64(0) - for dst.NumBytes() > 0 { + for ops.left() > 0 { // Pop the first buffer. first := p.data.Front() if first == nil { @@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Copy user data. - n, err := dst.CopyOutFrom(ctx, first) + n, err := ops.read(first) done += int64(n) p.size -= n - dst = dst.DropFirst64(n) // Empty buffer? if first.Empty() { @@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) return done, nil } +// dup duplicates all data from this pipe into the given writer. +// +// There is no blocking behavior implemented here. The writer may propagate +// some blocking error. All the writes must be complete writes. +func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Is the pipe empty? + if p.size == 0 { + if !p.HasWriters() { + // See above. + return 0, nil + } + return 0, syserror.ErrWouldBlock + } + + // Limit how much we consume. + if ops.left() > p.size { + ops.limit(p.size) + } + + done := int64(0) + for buf := p.data.Front(); buf != nil; buf = buf.Next() { + n, err := ops.read(buf) + done += n + if err != nil { + return done, err + } + } + + return done, nil +} + +type writeOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit should limit subsequent writes. + limit func(int64) + + // write should write to the provided buffer. + write func(*buffer) (int64, error) +} + // write writes data from sv into the pipe and returns the number of bytes // written. If no bytes are written because the pipe is full (or has less than // atomicIOBytes free capacity), write returns ErrWouldBlock. // // Precondition: this pipe must have writers. -func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) { +func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() @@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. - wanted := src.NumBytes() + wanted := ops.left() if avail := p.max - p.size; wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } - // Limit to the available capacity. - src = src.TakeFirst64(avail) + ops.limit(avail) } done := int64(0) - for src.NumBytes() > 0 { + for ops.left() > 0 { // Need a new buffer? last := p.data.Back() if last == nil || last.Full() { @@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) } // Copy user data. - n, err := src.CopyInTo(ctx, last) + n, err := ops.write(last) done += int64(n) p.size += n - src = src.DropFirst64(n) // Handle errors. if err != nil { diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index f69dbf27b..7c307f013 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "math" "syscall" @@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() { // Read implements fs.FileOperations.Read. func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.read(ctx, dst) + n, err := rw.Pipe.read(ctx, readOps{ + left: func() int64 { + return dst.NumBytes() + }, + limit: func(l int64) { + dst = dst.TakeFirst64(l) + }, + read: func(buf *buffer) (int64, error) { + n, err := dst.CopyOutFrom(ctx, buf) + dst = dst.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventOut) + } + return n, err +} + +// WriteTo implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { + ops := readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(buf *buffer) (int64, error) { + n, err := buf.ReadToWriter(w, count, dup) + count -= n + return n, err + }, + } + if dup { + // There is no notification for dup operations. + return rw.Pipe.dup(ctx, ops) + } + n, err := rw.Pipe.read(ctx, ops) if n > 0 { rw.Pipe.Notify(waiter.EventOut) } @@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, src) + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return src.NumBytes() + }, + limit: func(l int64) { + src = src.TakeFirst64(l) + }, + write: func(buf *buffer) (int64, error) { + n, err := src.CopyInTo(ctx, buf) + src = src.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventIn) + } + return n, err +} + +// ReadFrom implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(buf *buffer) (int64, error) { + n, err := buf.WriteFromReader(r, count) + count -= n + return n, err + }, + }) if n > 0 { rw.Pipe.Notify(waiter.EventIn) } diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go index c5d095af7..2e861a5a8 100644 --- a/pkg/sentry/kernel/posixtimer.go +++ b/pkg/sentry/kernel/posixtimer.go @@ -117,9 +117,9 @@ func (it *IntervalTimer) signalRejectedLocked() { } // Notify implements ktime.TimerListener.Notify. -func (it *IntervalTimer) Notify(exp uint64) { +func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { if it.target == nil { - return + return ktime.Setting{}, false } it.target.tg.pidns.owner.mu.RLock() @@ -129,7 +129,7 @@ func (it *IntervalTimer) Notify(exp uint64) { if it.sigpending { it.overrunCur += exp - return + return ktime.Setting{}, false } // sigpending must be set before sendSignalTimerLocked() so that it can be @@ -148,6 +148,8 @@ func (it *IntervalTimer) Notify(exp uint64) { if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil { it.signalRejectedLocked() } + + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. Users of Timer should call diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD index 1725b8562..98ea7a0d8 100644 --- a/pkg/sentry/kernel/sched/BUILD +++ b/pkg/sentry/kernel/sched/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 36edf10f3..80e5e5da3 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "waiter_list", diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index e5f297478..047b5214d 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -328,8 +328,14 @@ func (tg *ThreadGroup) createSession() error { childTG.processGroup.incRefWithParent(pg) childTG.processGroup.decRefWithParent(oldParentPG) }) - tg.processGroup.decRefWithParent(oldParentPG) + // If tg.processGroup is an orphan, decRefWithParent will lock + // the signal mutex of each thread group in tg.processGroup. + // However, tg's signal mutex may already be locked at this + // point. We change tg's process group before calling + // decRefWithParent to avoid locking tg's signal mutex twice. + oldPG := tg.processGroup tg.processGroup = pg + oldPG.decRefWithParent(oldParentPG) } else { // The current process group may be nil only in the case of an // unparented thread group (i.e. the init process). This would diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD new file mode 100644 index 000000000..50b69d154 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -0,0 +1,22 @@ +package(licenses = ["notice"]) + +load("//tools/go_stateify:defs.bzl", "go_library") + +go_library( + name = "signalfd", + srcs = ["signalfd.go"], + importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/sentry/context", + "//pkg/sentry/fs", + "//pkg/sentry/fs/anon", + "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", + "//pkg/sentry/usermem", + "//pkg/syserror", + "//pkg/waiter", + ], +) diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go new file mode 100644 index 000000000..4b08d7d72 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -0,0 +1,140 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package signalfd provides an implementation of signal file descriptors. +package signalfd + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/anon" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SignalOperations represent a file with signalfd semantics. +// +// +stateify savable +type SignalOperations struct { + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FilePipeSeek `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoFsync `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + // target is the original task target. + // + // The semantics here are a bit broken. Linux will always use current + // for all reads, regardless of where the signalfd originated. We can't + // do exactly that because we need to plumb the context through + // EventRegister in order to support proper blocking behavior. This + // will undoubtedly become very complicated quickly. + target *kernel.Task + + // mu protects below. + mu sync.Mutex `state:"nosave"` + + // mask is the signal mask. Protected by mu. + mask linux.SignalSet +} + +// New creates a new signalfd object with the supplied mask. +func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // No task context? Not valid. + return nil, syserror.EINVAL + } + // name matches fs/signalfd.c:signalfd4. + dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]") + return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{ + target: t, + mask: mask, + }), nil +} + +// Release implements fs.FileOperations.Release. +func (s *SignalOperations) Release() {} + +// Mask returns the signal mask. +func (s *SignalOperations) Mask() linux.SignalSet { + s.mu.Lock() + mask := s.mask + s.mu.Unlock() + return mask +} + +// SetMask sets the signal mask. +func (s *SignalOperations) SetMask(mask linux.SignalSet) { + s.mu.Lock() + s.mask = mask + s.mu.Unlock() +} + +// Read implements fs.FileOperations.Read. +func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { + // Attempt to dequeue relevant signals. + info, err := s.target.Sigtimedwait(s.Mask(), 0) + if err != nil { + // There must be no signal available. + return 0, syserror.ErrWouldBlock + } + + // Copy out the signal info using the specified format. + var buf [128]byte + binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + Signo: uint32(info.Signo), + Errno: info.Errno, + Code: info.Code, + PID: uint32(info.Pid()), + UID: uint32(info.Uid()), + Status: info.Status(), + Overrun: uint32(info.Overrun()), + Addr: info.Addr(), + }) + n, err := dst.CopyOut(ctx, buf[:]) + return int64(n), err +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + if mask&waiter.EventIn != 0 && s.target.PendingSignals()&s.Mask() != 0 { + return waiter.EventIn // Pending signals. + } + return 0 +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) { + // Register for the signal set; ignore the passed events. + s.target.SignalRegister(entry, waiter.EventMask(s.Mask())) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SignalOperations) EventUnregister(entry *waiter.Entry) { + // Unregister the original entry. + s.target.SignalUnregister(entry) +} diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index e91f82bb3..c82ef5486 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/third_party/gvsync" ) @@ -133,6 +134,13 @@ type Task struct { // signalStack is exclusive to the task goroutine. signalStack arch.SignalStack + // signalQueue is a set of registered waiters for signal-related events. + // + // signalQueue is protected by the signalMutex. Note that the task does + // not implement all queue methods, specifically the readiness checks. + // The task only broadcast a notification on signal delivery. + signalQueue waiter.Queue `state:"zerovalue"` + // If groupStopPending is true, the task should participate in a group // stop in the interrupt path. // diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index 2a2e6f662..dd69939f9 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -15,6 +15,7 @@ package kernel import ( + "runtime" "time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -121,6 +122,17 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { // Deactive our address space, we don't need it. interrupt := t.SleepStart() + // If the request is not completed, but the timer has already expired, + // then ensure that we run through a scheduler cycle. This is because + // we may see applications relying on timer slack to yield the thread. + // For example, they may attempt to sleep for some number of nanoseconds, + // and expect that this will actually yield the CPU and sleep for at + // least microseconds, e.g.: + // https://github.com/LMAX-Exchange/disruptor/commit/6ca210f2bcd23f703c479804d583718e16f43c07 + if len(timerChan) > 0 { + runtime.Gosched() + } + select { case <-C: t.SleepFinish(true) diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go index 78ff14b20..ce3e6ef28 100644 --- a/pkg/sentry/kernel/task_identity.go +++ b/pkg/sentry/kernel/task_identity.go @@ -465,8 +465,8 @@ func (t *Task) SetKeepCaps(k bool) { // disables the features we don't support anyway, is always set. This // drastically simplifies this function. // -// - We don't implement AT_SECURE, because no_new_privs always being set means -// that the conditions that require AT_SECURE never arise. (Compare Linux's +// - We don't set AT_SECURE = 1, because no_new_privs always being set means +// that the conditions that require AT_SECURE = 1 never arise. (Compare Linux's // security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().) // // - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index e76c069b0..8b148db35 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -126,12 +126,22 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) { t.gosched.Timestamp = now t.gosched.State = state t.goschedSeq.EndWrite() + + if state != TaskGoroutineRunningApp { + // Task is blocking/stopping. + t.k.decRunningTasks() + } } // Preconditions: The caller must be running on the task goroutine, and leaving // a state indicated by a previous call to // t.accountTaskGoroutineEnter(state). func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) { + if state != TaskGoroutineRunningApp { + // Task is unblocking/continuing. + t.k.incRunningTasks() + } + now := t.k.CPUClockNow() if t.gosched.State != state { panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys)) @@ -330,7 +340,7 @@ func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker { } // Notify implements ktime.TimerListener.Notify. -func (ticker *kernelCPUClockTicker) Notify(exp uint64) { +func (ticker *kernelCPUClockTicker) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { // Only increment cpuClock by 1 regardless of the number of expirations. // This approximately compensates for cases where thread throttling or bad // Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and @@ -426,6 +436,27 @@ func (ticker *kernelCPUClockTicker) Notify(exp uint64) { tgs[i] = nil } ticker.tgs = tgs[:0] + + // If nothing is running, we can disable the timer. + tasks := atomic.LoadInt64(&ticker.k.runningTasks) + if tasks == 0 { + ticker.k.runningTasksMu.Lock() + defer ticker.k.runningTasksMu.Unlock() + tasks := atomic.LoadInt64(&ticker.k.runningTasks) + if tasks != 0 { + // Raced with a 0 -> 1 transition. + return setting, false + } + + // Stop the timer. We must cache the current setting so the + // kernel can access it without violating the lock order. + ticker.k.cpuClockTickerSetting = setting + ticker.k.cpuClockTickerDisabled = true + setting.Enabled = false + return setting, true + } + + return setting, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 266959a07..39cd1340d 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -28,6 +28,7 @@ import ( ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" ) // SignalAction is an internal signal action. @@ -497,6 +498,9 @@ func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) { // // Preconditions: The signal mutex must be locked. func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool { + // Notify that the signal is queued. + t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig))) + // - Do not choose tasks that are blocking the signal. if linux.SignalSetOf(sig)&t.signalMask != 0 { return false @@ -1108,3 +1112,17 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { t.tg.signalHandlers.mu.Unlock() return t.deliverSignal(info, act) } + +// SignalRegister registers a waiter for pending signals. +func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) { + t.tg.signalHandlers.mu.Lock() + t.signalQueue.EventRegister(e, mask) + t.tg.signalHandlers.mu.Unlock() +} + +// SignalUnregister unregisters a waiter for pending signals. +func (t *Task) SignalUnregister(e *waiter.Entry) { + t.tg.signalHandlers.mu.Lock() + t.signalQueue.EventUnregister(e) + t.tg.signalHandlers.mu.Unlock() +} diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 0eef24bfb..72568d296 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -511,8 +511,9 @@ type itimerRealListener struct { } // Notify implements ktime.TimerListener.Notify. -func (l *itimerRealListener) Notify(exp uint64) { +func (l *itimerRealListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM)) + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index aa6c75d25..107394183 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -280,13 +280,16 @@ func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask { // A TimerListener receives expirations from a Timer. type TimerListener interface { // Notify is called when its associated Timer expires. exp is the number of - // expirations. + // expirations. setting is the next timer Setting. // // Notify is called with the associated Timer's mutex locked, so Notify // must not take any locks that precede Timer.mu in lock order. // + // If Notify returns true, the timer will use the returned setting + // rather than the passed one. + // // Preconditions: exp > 0. - Notify(exp uint64) + Notify(exp uint64, setting Setting) (newSetting Setting, update bool) // Destroy is called when the timer is destroyed. Destroy() @@ -533,7 +536,9 @@ func (t *Timer) Tick() { s, exp := t.setting.At(now) t.setting = s if exp > 0 { - t.listener.Notify(exp) + if newS, ok := t.listener.Notify(exp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) } @@ -588,7 +593,9 @@ func (t *Timer) Get() (Time, Setting) { s, exp := t.setting.At(now) t.setting = s if exp > 0 { - t.listener.Notify(exp) + if newS, ok := t.listener.Notify(exp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) return now, s @@ -620,7 +627,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) { } oldS, oldExp := t.setting.At(now) if oldExp > 0 { - t.listener.Notify(oldExp) + t.listener.Notify(oldExp, oldS) + // N.B. The returned Setting doesn't matter because we're about + // to overwrite. } if f != nil { f() @@ -628,7 +637,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) { newS, newExp := s.At(now) t.setting = newS if newExp > 0 { - t.listener.Notify(newExp) + if newS, ok := t.listener.Notify(newExp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) return now, oldS @@ -683,11 +694,13 @@ func NewChannelNotifier() (TimerListener, <-chan struct{}) { } // Notify implements ktime.TimerListener.Notify. -func (c *ChannelNotifier) Notify(uint64) { +func (c *ChannelNotifier) Notify(uint64, Setting) (Setting, bool) { select { case c.tchan <- struct{}{}: default: } + + return Setting{}, false } // Destroy implements ktime.TimerListener.Destroy and will close the channel. diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 40025d62d..59649c770 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "limits", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index bc5b841fb..2d9251e92 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -323,18 +323,22 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf. return syserror.ENOEXEC } + // N.B. Linux uses vm_brk_flags to map these pages, which only + // honors the X bit, always mapping at least RW. ignoring These + // pages are not included in the final brk region. + prot := usermem.ReadWrite + if phdr.Flags&elf.PF_X == elf.PF_X { + prot.Execute = true + } + if _, err := m.MMap(ctx, memmap.MMapOpts{ Length: uint64(anonSize), Addr: anonAddr, // Fixed without Unmap will fail the mmap if something is // already at addr. - Fixed: true, - Private: true, - // N.B. Linux uses vm_brk to map these pages, ignoring - // the segment protections, instead always mapping RW. - // These pages are not included in the final brk - // region. - Perms: usermem.ReadWrite, + Fixed: true, + Private: true, + Perms: prot, MaxPerms: usermem.AnyAccess, }); err != nil { ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err) @@ -464,7 +468,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el // base address big enough to fit all segments, so we first create a // mapping for the total size just to find a region that is big enough. // - // It is safe to unmap it immediately with racing with another mapping + // It is safe to unmap it immediately without racing with another mapping // because we are the only one in control of the MemoryManager. // // Note that the vaddr of the first PT_LOAD segment is ignored when diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index f6f1ae762..089d1635b 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -308,6 +308,9 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())}, arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())}, arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())}, + // The conditions that require AT_SECURE = 1 never arise. See + // kernel.Task.updateCredsForExecLocked. + arch.AuxEntry{linux.AT_SECURE, 0}, arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC}, arch.AuxEntry{linux.AT_EXECFN, execfn}, arch.AuxEntry{linux.AT_RANDOM, random}, diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 29c14ec56..9687e7e76 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "mappable_range", diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 072745a08..b35c8c673 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "file_refcount_set", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 858f895f2..3fd904c67 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "evictable_range", diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index eeb634644..b6d008dbe 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index ad8b95744..31fa48ec5 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -54,6 +55,7 @@ go_test( ], embed = [":kvm"], tags = [ + "manual", "nogotsan", "requires-kvm", ], diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD index 77a449a8b..b0e45f159 100644 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ b/pkg/sentry/platform/kvm/testutil/BUILD @@ -9,6 +9,8 @@ go_library( "testutil.go", "testutil_amd64.go", "testutil_amd64.s", + "testutil_arm64.go", + "testutil_arm64.s", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil", visibility = ["//pkg/sentry/platform/kvm:__pkg__"], diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go index 6cf2359a3..5c1efa0fd 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil.go +++ b/pkg/sentry/platform/kvm/testutil/testutil.go @@ -41,9 +41,6 @@ func TwiddleRegsFault() // TwiddleRegsSyscall twiddles registers then executes a syscall. func TwiddleRegsSyscall() -// TwiddleSegments reads segments into known registers. -func TwiddleSegments() - // FloatingPointWorks is a floating point test. // // It returns true or false. diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go index 203d71528..4c108abbf 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go @@ -21,6 +21,9 @@ import ( "syscall" ) +// TwiddleSegments reads segments into known registers. +func TwiddleSegments() + // SetTestTarget sets the rip appropriately. func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { regs.Rip = uint64(reflect.ValueOf(fn).Pointer()) diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go new file mode 100644 index 000000000..40b2e4acc --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -0,0 +1,59 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package testutil + +import ( + "fmt" + "reflect" + "syscall" +) + +// SetTestTarget sets the rip appropriately. +func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { + regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) +} + +// SetTouchTarget sets rax appropriately. +func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { + if target != nil { + regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer()) + } else { + regs.Regs[8] = 0 + } +} + +// RewindSyscall rewinds a syscall RIP. +func RewindSyscall(regs *syscall.PtraceRegs) { + regs.Pc -= 4 +} + +// SetTestRegs initializes registers to known values. +func SetTestRegs(regs *syscall.PtraceRegs) { + for i := 0; i <= 30; i++ { + regs.Regs[i] = uint64(i) + 1 + } +} + +// CheckTestRegs checks that registers were twiddled per TwiddleRegs. +func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { + for i := 0; i <= 30; i++ { + if need := ^uint64(i + 1); regs.Regs[i] != need { + err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need) + } + } + return +} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s new file mode 100644 index 000000000..2cd28b2d2 --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -0,0 +1,91 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +// test_util_arm64.s provides ARM64 test functions. + +#include "funcdata.h" +#include "textflag.h" + +#define SYS_GETPID 172 + +// This function simulates the getpid syscall. +TEXT ·Getpid(SB),NOSPLIT,$0 + NO_LOCAL_POINTERS + MOVD $SYS_GETPID, R8 + SVC + RET + +TEXT ·Touch(SB),NOSPLIT,$0 +start: + MOVD 0(R8), R1 + MOVD $SYS_GETPID, R8 // getpid + SVC + B start + +TEXT ·HaltLoop(SB),NOSPLIT,$0 +start: + HLT + B start + +// This function simulates a loop of syscall. +TEXT ·SyscallLoop(SB),NOSPLIT,$0 +start: + SVC + B start + +TEXT ·SpinLoop(SB),NOSPLIT,$0 +start: + B start + +// MVN: bitwise logical NOT +// This case simulates an application that modified R0-R30. +#define TWIDDLE_REGS() \ + MVN R0, R0; \ + MVN R1, R1; \ + MVN R2, R2; \ + MVN R3, R3; \ + MVN R4, R4; \ + MVN R5, R5; \ + MVN R6, R6; \ + MVN R7, R7; \ + MVN R8, R8; \ + MVN R9, R9; \ + MVN R10, R10; \ + MVN R11, R11; \ + MVN R12, R12; \ + MVN R13, R13; \ + MVN R14, R14; \ + MVN R15, R15; \ + MVN R16, R16; \ + MVN R17, R17; \ + MVN R18_PLATFORM, R18_PLATFORM; \ + MVN R19, R19; \ + MVN R20, R20; \ + MVN R21, R21; \ + MVN R22, R22; \ + MVN R23, R23; \ + MVN R24, R24; \ + MVN R25, R25; \ + MVN R26, R26; \ + MVN R27, R27; \ + MVN g, g; \ + MVN R29, R29; \ + MVN R30, R30; + +TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 + TWIDDLE_REGS() + SVC + RET // never reached diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 47957bb3b..72c7ec564 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -154,3 +154,19 @@ func (t *thread) clone() (*thread, error) { cpu: ^uint32(0), }, nil } + +// getEventMessage retrieves a message about the ptrace event that just happened. +func (t *thread) getEventMessage() (uintptr, error) { + var msg uintptr + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_GETEVENTMSG, + uintptr(t.tid), + 0, + uintptr(unsafe.Pointer(&msg)), + 0, 0) + if errno != 0 { + return msg, errno + } + return msg, nil +} diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 6bf7cd097..9f0ecfbe4 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -267,7 +267,7 @@ func (s *subprocess) newThread() *thread { // attach attaches to the thread. func (t *thread) attach() { - if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("unable to attach: %v", errno)) } @@ -355,7 +355,8 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal { } if stopSig == syscall.SIGTRAP { if status.TrapCause() == syscall.PTRACE_EVENT_EXIT { - t.dumpAndPanic("wait failed: the process exited") + msg, err := t.getEventMessage() + t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) } // Re-encode the trap cause the way it's expected. return stopSig | syscall.Signal(status.TrapCause()<<8) @@ -416,7 +417,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { for { // Execute the syscall instruction. - if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) } @@ -426,12 +427,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { break } else { // Some other signal caused a thread stop; ignore. + if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD { + log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig) + } continue } } // Complete the actual system call. - if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) } @@ -522,17 +526,17 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { for { // Start running until the next system call. if isSingleStepping(regs) { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall6( syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU_SINGLESTEP, - uintptr(t.tid), 0); errno != 0 { + uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) } } else { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall6( syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, - uintptr(t.tid), 0); errno != 0 { + uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) } } diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index f09b0b3d0..c075b5f91 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -53,7 +53,7 @@ func probeSeccomp() bool { for { // Attempt an emulation. - if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) } @@ -266,7 +266,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // Enable cpuid-faulting; this may fail on older kernels or hardware, // so we just disregard the result. Host CPUID will be enabled. - syscall.RawSyscall(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0) + syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0) // Call the stub; should not return. stubCall(stubStart, ppid) diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 3b95af617..ea090b686 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD index 924d8a6d6..6769cd0a5 100644 --- a/pkg/sentry/platform/safecopy/BUILD +++ b/pkg/sentry/platform/safecopy/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD index fd6dc8e6e..884020f7b 100644 --- a/pkg/sentry/safemem/BUILD +++ b/pkg/sentry/safemem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index eace3766d..c303435d5 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -23,7 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -// TODO(b/34161764): Move to pkg/abi/linux along with definitions in +// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in // pkg/sentry/arch. type sigaction struct { handler uintptr diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 45ebb2a0e..7da68384e 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", + "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 9e2e12799..445080aa4 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "port", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d0aab293d..b2732ca29 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -416,6 +417,24 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have Peek: flags&linux.MSG_PEEK != 0, } + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. However, the memory manager for the destination will not read + // the endpoint if the destination is zero length. + // + // In order for the endpoint to be read when the destination size is zero, + // we must cause a read of the endpoint by using a separate fake zero + // length block sequence and calling the EndpointReader directly. + if trunc && dst.Addrs.NumBytes() == 0 { + // Perform a read to a zero byte block sequence. We can ignore the + // original destination since it was zero bytes. The length returned by + // ReadToBlocks is ignored and we return the full message length to comply + // with MSG_TRUNC. + _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0)))) + return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + } + if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { @@ -499,6 +518,9 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error PortID: uint32(ms.PortID), }) + // Add the dump_done_errno payload. + m.Put(int64(0)) + _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/netstack/BUILD index e927821e1..60523f79a 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -3,15 +3,15 @@ package(licenses = ["notice"]) load("//tools/go_stateify:defs.bzl", "go_library") go_library( - name = "epsocket", + name = "netstack", srcs = [ "device.go", - "epsocket.go", + "netstack.go", "provider.go", "save_restore.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/epsocket", + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack", visibility = [ "//pkg/sentry:internal", ], diff --git a/pkg/sentry/socket/epsocket/device.go b/pkg/sentry/socket/netstack/device.go index 85484d5b1..fbeb89fb8 100644 --- a/pkg/sentry/socket/epsocket/device.go +++ b/pkg/sentry/socket/netstack/device.go @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import "gvisor.dev/gvisor/pkg/sentry/device" -// epsocketDevice is the endpoint socket virtual device. -var epsocketDevice = device.NewAnonDevice() +// netstackDevice is the endpoint socket virtual device. +var netstackDevice = device.NewAnonDevice() diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/netstack/netstack.go index 635042263..0ae573b45 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package epsocket provides an implementation of the socket.Socket interface +// Package netstack provides an implementation of the socket.Socket interface // that is backed by a tcpip.Endpoint. // // It does not depend on any particular endpoint implementation, and thus can @@ -22,17 +22,20 @@ // Lock ordering: netstack => mm: ioSequencePayload copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. -package epsocket +package netstack import ( "bytes" + "io" "math" + "reflect" "sync" "syscall" "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -52,6 +55,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -133,11 +137,13 @@ var Metrics = tcpip.Stats{ }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), @@ -151,6 +157,7 @@ var Metrics = tcpip.Stats{ ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."), InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."), SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."), + SegmentSendErrors: mustCreateMetric("/netstack/tcp/segment_send_errors", "Number of TCP segments failed to be sent."), ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."), ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."), Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."), @@ -166,13 +173,18 @@ var Metrics = tcpip.Stats{ UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."), ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."), MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), - PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent via sendUDP."), + PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), + PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), }, } +// DefaultTTL is linux's default TTL. All network protocols in all stacks used +// with this package must have this value set as their default TTL. +const DefaultTTL = 64 + const sizeOfInt32 int = 4 -var errStackType = syserr.New("expected but did not receive an epsocket.Stack", linux.EINVAL) +var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) // ntohs converts a 16-bit number from network byte order to host byte order. It // assumes that the host is little endian. @@ -205,6 +217,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and + // transport.Endpoint.SetSockOptInt. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error @@ -224,7 +240,6 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout *waiter.Queue @@ -255,8 +270,8 @@ type SocketOperations struct { // valid when timestampValid is true. It is protected by readMu. timestampNS int64 - // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket - // level, because it takes into account data from readView. + // sockOptInq corresponds to TCP_INQ. It is implemented at this level + // because it takes into account data from readView. sockOptInq bool } @@ -268,7 +283,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue } } - dirent := socket.NewDirent(t, epsocketDevice) + dirent := socket.NewDirent(t, netstackDevice) defer dirent.DecRef() return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ Queue: queue, @@ -409,17 +424,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return int64(n), nil } -// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand -// based on the requested size. +// WriteTo implements fs.FileOperations.WriteTo. +func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { + s.readMu.Lock() + + // Copy as much data as possible. + done := int64(0) + for count > 0 { + // This may return a blocking error. + if err := s.fetchReadView(); err != nil { + s.readMu.Unlock() + return done, err.ToError() + } + + // Write to the underlying file. + n, err := dst.Write(s.readView) + done += int64(n) + count -= int64(n) + if dup { + // That's all we support for dup. This is generally + // supported by any Linux system calls, but the + // expectation is that now a caller will call read to + // actually remove these bytes from the socket. + break + } + + // Drop that part of the view. + s.readView.TrimFront(n) + if err != nil { + s.readMu.Unlock() + return done, err + } + } + + s.readMu.Unlock() + return done, nil +} + +// ioSequencePayload implements tcpip.Payload. +// +// t copies user memory bytes on demand based on the requested size. type ioSequencePayload struct { ctx context.Context src usermem.IOSequence } -// Get implements tcpip.Payload. -func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { - if size > i.Size() { - size = i.Size() +// FullPayload implements tcpip.Payloader.FullPayload +func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { + return i.Payload(int(i.src.NumBytes())) +} + +// Payload implements tcpip.Payloader.Payload. +func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { + if max := int(i.src.NumBytes()); size > max { + size = max } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { @@ -428,11 +486,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { return v, nil } -// Size implements tcpip.Payload. -func (i *ioSequencePayload) Size() int { - return int(i.src.NumBytes()) -} - // DropFirst drops the first n bytes from underlying src. func (i *ioSequencePayload) DropFirst(n int) { i.src = i.src.DropFirst(int(n)) @@ -466,6 +519,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), nil } +// readerPayload implements tcpip.Payloader. +// +// It allocates a view and reads from a reader on-demand, based on available +// capacity in the endpoint. +type readerPayload struct { + ctx context.Context + r io.Reader + count int64 + err error +} + +// FullPayload implements tcpip.Payloader.FullPayload. +func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { + return r.Payload(int(r.count)) +} + +// Payload implements tcpip.Payloader.Payload. +func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { + if size > int(r.count) { + size = int(r.count) + } + v := buffer.NewView(size) + n, err := r.r.Read(v) + if n > 0 { + // We ignore the error here. It may re-occur on subsequent + // reads, but for now we can enqueue some amount of data. + r.count -= int64(n) + return v[:n], nil + } + if err == syserror.ErrWouldBlock { + return nil, tcpip.ErrWouldBlock + } else if err != nil { + r.err = err // Save for propation. + return nil, tcpip.ErrBadAddress + } + + // There is no data and no error. Return an error, which will propagate + // r.err, which will be nil. This is the desired result: (0, nil). + return nil, tcpip.ErrBadAddress +} + +// ReadFrom implements fs.FileOperations.ReadFrom. +func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + f := &readerPayload{ctx: ctx, r: r, count: count} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ + Atomic: true, // See above. + }) + } + if err == tcpip.ErrWouldBlock { + return n, syserror.ErrWouldBlock + } else if err != nil { + return int64(n), f.err // Propagate error. + } + + return int64(n), nil +} + // Readiness returns a mask of ready events for socket s. func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) @@ -643,7 +768,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // tcpip.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is - // implemented specifically for epsocket.SocketOperations rather than + // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket // options where the implementation is not shared, as unix sockets need // their own support for SO_TIMESTAMP. @@ -716,7 +841,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return getSockOptIPv6(t, ep, name, outLen) case linux.SOL_IP: - return getSockOptIP(t, ep, name, outLen) + return getSockOptIP(t, ep, name, outLen, family) case linux.SOL_UDP, linux.SOL_ICMPV6, @@ -774,8 +899,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -790,8 +915,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -825,6 +950,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BINDTODEVICE: + var v tcpip.BindToDeviceOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if len(v) == 0 { + return []byte{}, nil + } + if outLen < linux.IFNAMSIZ { + return nil, syserr.ErrInvalidArgument + } + return append([]byte(v), 0), nil + case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1039,6 +1177,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_TCLASS: + // Length handling for parity with Linux. + if outLen == 0 { + return make([]byte, 0), nil + } + var v tcpip.IPv6TrafficClassOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + uintv := uint32(v) + // Linux truncates the output binary to outLen. + ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + // Handle cases where outLen is lesser than sizeOfInt32. + if len(ib) > outLen { + ib = ib[:outLen] + } + return ib, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1046,8 +1203,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { switch name { + case linux.IP_TTL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TTLOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + // Fill in the default value, if needed. + if v == 0 { + v = DefaultTTL + } + + return int32(v), nil + case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1089,6 +1263,20 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac } return int32(0), nil + case linux.IP_TOS: + // Length handling for parity with Linux. + if outLen == 0 { + return []byte(nil), nil + } + var v tcpip.IPv4TOSOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if outLen < sizeOfInt32 { + return uint8(v), nil + } + return int32(v), nil + default: emitUnimplementedEventIP(t, name) } @@ -1099,7 +1287,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac // tcpip.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is - // implemented specifically for epsocket.SocketOperations rather than + // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket // options where the implementation is not shared, as unix sockets need // their own support for SO_TIMESTAMP. @@ -1162,7 +1350,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1170,7 +1358,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -1188,6 +1376,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BINDTODEVICE: + n := bytes.IndexByte(optVal, 0) + if n == -1 { + n = len(optVal) + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -1380,6 +1575,19 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_TCLASS: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < -1 || v > 255 { + return syserr.ErrInvalidArgument + } + if v == -1 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + default: emitUnimplementedEventIPv6(t, name) } @@ -1511,6 +1719,30 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s t.Kernel().EmitUnimplementedEvent(t) return syserr.ErrInvalidArgument + case linux.IP_TTL: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + // -1 means default TTL. + if v == -1 { + v = 0 + } else if v < 1 || v > 255 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + + case linux.IP_TOS: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1534,9 +1766,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, - linux.IP_TOS, linux.IP_TRANSPARENT, - linux.IP_TTL, linux.IP_UNBLOCK_SOURCE, linux.IP_UNICAST_IF, linux.IP_XFRM_POLICY, @@ -2057,7 +2287,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= int64(v.Size()) || dontWait) { + if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. return int(n), nil } @@ -2082,7 +2312,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.TranslateNetstackError(err) } - if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } @@ -2098,10 +2328,11 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. - if int(args[1].Int()) == syscall.SIOCGSTAMP { + switch args[1].Int() { + case syscall.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2113,6 +2344,25 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, AddressSpaceActive: true, }) return 0, err + + case linux.TIOCINQ: + v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() + } + + // Add bytes removed from the endpoint but not yet sent to the caller. + v += len(s.readView) + + if v > math.MaxInt32 { + v = math.MaxInt32 + } + + // Copy result to user-space. + _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err } return Ioctl(ctx, s.Endpoint, io, args) @@ -2184,9 +2434,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCOUTQ: - var v tcpip.SendQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { @@ -2381,7 +2631,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Flag values and meanings are described in greater detail in netdevice(7) in // the SIOCGIFFLAGS section. func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) { - // epsocket should only ever be passed an epsocket.Stack. + // We should only ever be passed a netstack.Stack. epstack, ok := stack.(*Stack) if !ok { return 0, errStackType @@ -2421,7 +2671,8 @@ func (s *SocketOperations) State() uint32 { return 0 } - if !s.isPacketBased() { + switch { + case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -2450,9 +2701,26 @@ func (s *SocketOperations) State() uint32 { // Internal or unknown state. return 0 } + case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + // UDP socket. + switch udp.EndpointState(s.Endpoint.State()) { + case udp.StateInitial, udp.StateBound, udp.StateClosed: + return linux.TCP_CLOSE + case udp.StateConnected: + return linux.TCP_ESTABLISHED + default: + return 0 + } + case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + // TODO(b/112063468): Export states for ICMP sockets. + case s.skType == linux.SOCK_RAW: + // TODO(b/112063468): Export states for raw sockets. + default: + // Unknown transport protocol, how did we make this socket? + log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem()) + return 0 } - // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/netstack/provider.go index 421f93dc4..357a664cc 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "syscall" @@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { - return 0, true, syserr.ErrPermissionDenied + return 0, true, syserr.ErrNotPermitted } switch protocol { diff --git a/pkg/sentry/socket/epsocket/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go index f7b8c10cc..c7aaf722a 100644 --- a/pkg/sentry/socket/epsocket/save_restore.go +++ b/pkg/sentry/socket/netstack/save_restore.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/netstack/stack.go index 7cf7ff735..fda0156e5 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "gvisor.dev/gvisor/pkg/abi/linux" diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 5061dcbde..3a6baa308 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -49,6 +50,14 @@ proto_library( ], ) +cc_proto_library( + name = "syscall_rpc_cc_proto", + visibility = [ + "//visibility:public", + ], + deps = [":syscall_rpc_proto"], +) + go_proto_library( name = "syscall_rpc_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto", diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index da9977fde..830f4da10 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -24,7 +24,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", + "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index 1c71609e2..e27b1c714 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -161,7 +161,8 @@ func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) { if q.dataList.Front() == nil { err := syserr.ErrWouldBlock if q.closed { - if err = syserr.ErrClosedForReceive; q.unread { + err = syserr.ErrClosedForReceive + if q.unread { err = syserr.ErrConnectionReset } } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index a103b1aab..529a7a7a9 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has + // the int type. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error @@ -847,6 +851,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -862,65 +870,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrQueueSizeNotSupported } return v, nil - default: - return -1, tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - case *tcpip.SendQueueSizeOption: + case tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v := e.connected.SendQueuedSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported - } - *o = qs - return nil - - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - return nil + return int(v), nil - case *tcpip.SendBufferSizeOption: + case tcpip.SendBufferSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + v := e.connected.SendMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs - return nil + return int(v), nil - case *tcpip.ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + v := e.receiver.RecvMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return int(v), nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) } - *o = qs return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 8e4f06a22..1aaae8487 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -31,7 +31,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" @@ -40,8 +40,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// SocketOperations is a Unix socket. It is similar to an epsocket, except it -// is backed by a transport.Endpoint instead of a tcpip.Endpoint. +// SocketOperations is a Unix socket. It is similar to a netstack socket, +// except it is backed by a transport.Endpoint instead of a tcpip.Endpoint. // // +stateify savable type SocketOperations struct { @@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { return "", err } @@ -143,7 +143,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, return nil, 0, syserr.TranslateNetstackError(err) } - a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr) + a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -155,19 +155,19 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, return nil, 0, syserr.TranslateNetstackError(err) } - a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr) + a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - return epsocket.Ioctl(ctx, s.ep, io, args) + return netstack.Ioctl(ctx, s.ep, io, args) } // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { - return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by @@ -474,13 +474,13 @@ func (s *SocketOperations) EventUnregister(e *waiter.Entry) { // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { - return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal) + return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { - f, err := epsocket.ConvertShutdown(how) + f, err := netstack.ConvertShutdown(how) if err != nil { return err } @@ -546,7 +546,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -581,7 +581,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 445d25010..72ebf766d 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -31,8 +32,8 @@ go_library( "//pkg/sentry/arch", "//pkg/sentry/kernel", "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", "//pkg/sentry/socket/netlink", + "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", "//pkg/sentry/usermem", ], @@ -44,6 +45,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "strace_cc_proto", + visibility = ["//visibility:public"], + deps = [":strace_proto"], +) + go_proto_library( name = "strace_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto", diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go index 3650fd6e1..5d57b75af 100644 --- a/pkg/sentry/strace/linux64.go +++ b/pkg/sentry/strace/linux64.go @@ -335,4 +335,5 @@ var linuxAMD64 = SyscallMap{ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex), 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex), 317: makeSyscallInfo("seccomp", Hex, Hex, Hex), + 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex), } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f779186ad..94334f6d2 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/control" - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/usermem" ) @@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */) + fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 33a40b9c6..cf2a56bed 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -8,6 +8,8 @@ go_library( "error.go", "flags.go", "linux64.go", + "linux64_amd64.go", + "linux64_arm64.go", "sigset.go", "sys_aio.go", "sys_capability.go", @@ -74,6 +76,7 @@ go_library( "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/sched", "//pkg/sentry/kernel/shm", + "//pkg/sentry/kernel/signalfd", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/memmap", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ed996ba51..b64c49ff5 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,377 +15,8 @@ // Package linux provides syscall tables for amd64 Linux. package linux -import ( - "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/syscalls" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" +const ( + _LINUX_SYSNAME = "Linux" + _LINUX_RELEASE = "4.4" + _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016" ) - -// AUDIT_ARCH_X86_64 identifies the Linux syscall API on AMD64, and is taken -// from <linux/audit.h>. -const _AUDIT_ARCH_X86_64 = 0xc000003e - -// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall -// numbers from Linux 4.4. -var AMD64 = &kernel.SyscallTable{ - OS: abi.Linux, - Arch: arch.AMD64, - Version: kernel.Version{ - // Version 4.4 is chosen as a stable, longterm version of Linux, which - // guides the interface provided by this syscall table. The build - // version is that for a clean build with default kernel config, at 5 - // minutes after v4.4 was tagged. - Sysname: "Linux", - Release: "4.4", - Version: "#1 SMP Sun Jan 10 15:06:54 PST 2016", - }, - AuditNumber: _AUDIT_ARCH_X86_64, - Table: map[uintptr]kernel.Syscall{ - 0: syscalls.Supported("read", Read), - 1: syscalls.Supported("write", Write), - 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil), - 3: syscalls.Supported("close", Close), - 4: syscalls.Supported("stat", Stat), - 5: syscalls.Supported("fstat", Fstat), - 6: syscalls.Supported("lstat", Lstat), - 7: syscalls.Supported("poll", Poll), - 8: syscalls.Supported("lseek", Lseek), - 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil), - 10: syscalls.Supported("mprotect", Mprotect), - 11: syscalls.Supported("munmap", Munmap), - 12: syscalls.Supported("brk", Brk), - 13: syscalls.Supported("rt_sigaction", RtSigaction), - 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask), - 15: syscalls.Supported("rt_sigreturn", RtSigreturn), - 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), - 17: syscalls.Supported("pread64", Pread64), - 18: syscalls.Supported("pwrite64", Pwrite64), - 19: syscalls.Supported("readv", Readv), - 20: syscalls.Supported("writev", Writev), - 21: syscalls.Supported("access", Access), - 22: syscalls.Supported("pipe", Pipe), - 23: syscalls.Supported("select", Select), - 24: syscalls.Supported("sched_yield", SchedYield), - 25: syscalls.Supported("mremap", Mremap), - 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), - 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), - 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), - 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), - 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), - 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), - 32: syscalls.Supported("dup", Dup), - 33: syscalls.Supported("dup2", Dup2), - 34: syscalls.Supported("pause", Pause), - 35: syscalls.Supported("nanosleep", Nanosleep), - 36: syscalls.Supported("getitimer", Getitimer), - 37: syscalls.Supported("alarm", Alarm), - 38: syscalls.Supported("setitimer", Setitimer), - 39: syscalls.Supported("getpid", Getpid), - 40: syscalls.Supported("sendfile", Sendfile), - 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), - 42: syscalls.Supported("connect", Connect), - 43: syscalls.Supported("accept", Accept), - 44: syscalls.Supported("sendto", SendTo), - 45: syscalls.Supported("recvfrom", RecvFrom), - 46: syscalls.Supported("sendmsg", SendMsg), - 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), - 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), - 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), - 50: syscalls.Supported("listen", Listen), - 51: syscalls.Supported("getsockname", GetSockName), - 52: syscalls.Supported("getpeername", GetPeerName), - 53: syscalls.Supported("socketpair", SocketPair), - 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), - 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), - 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), - 57: syscalls.Supported("fork", Fork), - 58: syscalls.Supported("vfork", Vfork), - 59: syscalls.Supported("execve", Execve), - 60: syscalls.Supported("exit", Exit), - 61: syscalls.Supported("wait4", Wait4), - 62: syscalls.Supported("kill", Kill), - 63: syscalls.Supported("uname", Uname), - 64: syscalls.Supported("semget", Semget), - 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), - 67: syscalls.Supported("shmdt", Shmdt), - 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), - 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), - 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), - 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), - 76: syscalls.Supported("truncate", Truncate), - 77: syscalls.Supported("ftruncate", Ftruncate), - 78: syscalls.Supported("getdents", Getdents), - 79: syscalls.Supported("getcwd", Getcwd), - 80: syscalls.Supported("chdir", Chdir), - 81: syscalls.Supported("fchdir", Fchdir), - 82: syscalls.Supported("rename", Rename), - 83: syscalls.Supported("mkdir", Mkdir), - 84: syscalls.Supported("rmdir", Rmdir), - 85: syscalls.Supported("creat", Creat), - 86: syscalls.Supported("link", Link), - 87: syscalls.Supported("unlink", Unlink), - 88: syscalls.Supported("symlink", Symlink), - 89: syscalls.Supported("readlink", Readlink), - 90: syscalls.Supported("chmod", Chmod), - 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), - 92: syscalls.Supported("chown", Chown), - 93: syscalls.Supported("fchown", Fchown), - 94: syscalls.Supported("lchown", Lchown), - 95: syscalls.Supported("umask", Umask), - 96: syscalls.Supported("gettimeofday", Gettimeofday), - 97: syscalls.Supported("getrlimit", Getrlimit), - 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), - 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), - 100: syscalls.Supported("times", Times), - 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), - 102: syscalls.Supported("getuid", Getuid), - 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), - 104: syscalls.Supported("getgid", Getgid), - 105: syscalls.Supported("setuid", Setuid), - 106: syscalls.Supported("setgid", Setgid), - 107: syscalls.Supported("geteuid", Geteuid), - 108: syscalls.Supported("getegid", Getegid), - 109: syscalls.Supported("setpgid", Setpgid), - 110: syscalls.Supported("getppid", Getppid), - 111: syscalls.Supported("getpgrp", Getpgrp), - 112: syscalls.Supported("setsid", Setsid), - 113: syscalls.Supported("setreuid", Setreuid), - 114: syscalls.Supported("setregid", Setregid), - 115: syscalls.Supported("getgroups", Getgroups), - 116: syscalls.Supported("setgroups", Setgroups), - 117: syscalls.Supported("setresuid", Setresuid), - 118: syscalls.Supported("getresuid", Getresuid), - 119: syscalls.Supported("setresgid", Setresgid), - 120: syscalls.Supported("getresgid", Getresgid), - 121: syscalls.Supported("getpgid", Getpgid), - 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) - 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) - 124: syscalls.Supported("getsid", Getsid), - 125: syscalls.Supported("capget", Capget), - 126: syscalls.Supported("capset", Capset), - 127: syscalls.Supported("rt_sigpending", RtSigpending), - 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), - 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), - 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend), - 131: syscalls.Supported("sigaltstack", Sigaltstack), - 132: syscalls.Supported("utime", Utime), - 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil), - 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil), - 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), - 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil), - 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), - 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), - 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}), - 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), - 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), - 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), - 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), - 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), - 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), - 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), - 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), - 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), - 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), - 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil), - 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil), - 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil), - 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), - 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil), - 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), - 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), - 161: syscalls.Supported("chroot", Chroot), - 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), - 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), - 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), - 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), - 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), - 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), - 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), - 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), - 170: syscalls.Supported("sethostname", Sethostname), - 171: syscalls.Supported("setdomainname", Setdomainname), - 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil), - 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil), - 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil), - 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), - 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), - 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), - 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), - 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations - 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), - 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), - 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), - 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil), - 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), - 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), - 186: syscalls.Supported("gettid", Gettid), - 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341) - 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 200: syscalls.Supported("tkill", Tkill), - 201: syscalls.Supported("time", Time), - 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), - 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), - 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), - 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), - 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), - 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), - 213: syscalls.Supported("epoll_create", EpollCreate), - 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil), - 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil), - 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), - 217: syscalls.Supported("getdents64", Getdents64), - 218: syscalls.Supported("set_tid_address", SetTidAddress), - 219: syscalls.Supported("restart_syscall", RestartSyscall), - 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) - 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), - 222: syscalls.Supported("timer_create", TimerCreate), - 223: syscalls.Supported("timer_settime", TimerSettime), - 224: syscalls.Supported("timer_gettime", TimerGettime), - 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun), - 226: syscalls.Supported("timer_delete", TimerDelete), - 227: syscalls.Supported("clock_settime", ClockSettime), - 228: syscalls.Supported("clock_gettime", ClockGettime), - 229: syscalls.Supported("clock_getres", ClockGetres), - 230: syscalls.Supported("clock_nanosleep", ClockNanosleep), - 231: syscalls.Supported("exit_group", ExitGroup), - 232: syscalls.Supported("epoll_wait", EpollWait), - 233: syscalls.Supported("epoll_ctl", EpollCtl), - 234: syscalls.Supported("tgkill", Tgkill), - 235: syscalls.Supported("utimes", Utimes), - 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil), - 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), - 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), - 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), - 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), - 247: syscalls.Supported("waitid", Waitid), - 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), - 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), - 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), - 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) - 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) - 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil), - 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), - 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), - 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), - 257: syscalls.Supported("openat", Openat), - 258: syscalls.Supported("mkdirat", Mkdirat), - 259: syscalls.Supported("mknodat", Mknodat), - 260: syscalls.Supported("fchownat", Fchownat), - 261: syscalls.Supported("futimesat", Futimesat), - 262: syscalls.Supported("fstatat", Fstatat), - 263: syscalls.Supported("unlinkat", Unlinkat), - 264: syscalls.Supported("renameat", Renameat), - 265: syscalls.Supported("linkat", Linkat), - 266: syscalls.Supported("symlinkat", Symlinkat), - 267: syscalls.Supported("readlinkat", Readlinkat), - 268: syscalls.Supported("fchmodat", Fchmodat), - 269: syscalls.Supported("faccessat", Faccessat), - 270: syscalls.Supported("pselect", Pselect), - 271: syscalls.Supported("ppoll", Ppoll), - 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), - 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), - 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) - 280: syscalls.Supported("utimensat", Utimensat), - 281: syscalls.Supported("epoll_pwait", EpollPwait), - 282: syscalls.ErrorWithEvent("signalfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) - 283: syscalls.Supported("timerfd_create", TimerfdCreate), - 284: syscalls.Supported("eventfd", Eventfd), - 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), - 286: syscalls.Supported("timerfd_settime", TimerfdSettime), - 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), - 288: syscalls.Supported("accept4", Accept4), - 289: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) - 290: syscalls.Supported("eventfd2", Eventfd2), - 291: syscalls.Supported("epoll_create1", EpollCreate1), - 292: syscalls.Supported("dup3", Dup3), - 293: syscalls.Supported("pipe2", Pipe2), - 294: syscalls.Supported("inotify_init1", InotifyInit1), - 295: syscalls.Supported("preadv", Preadv), - 296: syscalls.Supported("pwritev", Pwritev), - 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), - 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), - 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), - 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), - 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), - 302: syscalls.Supported("prlimit64", Prlimit64), - 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), - 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), - 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), - 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), - 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), - 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) - 309: syscalls.Supported("getcpu", Getcpu), - 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), - 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), - 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), - 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), - 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) - 317: syscalls.Supported("seccomp", Seccomp), - 318: syscalls.Supported("getrandom", GetRandom), - 319: syscalls.Supported("memfd_create", MemfdCreate), - 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), - 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), - 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) - 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) - 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - - // Syscalls after 325 are "backports" from versions of Linux after 4.4. - 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), - 327: syscalls.Supported("preadv2", Preadv2), - 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), - 332: syscalls.Supported("statx", Statx), - }, - - Emulate: map[usermem.Addr]uintptr{ - 0xffffffffff600000: 96, // vsyscall gettimeofday(2) - 0xffffffffff600400: 201, // vsyscall time(2) - 0xffffffffff600800: 309, // vsyscall getcpu(2) - }, - Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { - t.Kernel().EmitUnimplementedEvent(t) - return 0, syserror.ENOSYS - }, -} diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go new file mode 100644 index 000000000..e215ac049 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -0,0 +1,386 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall +// numbers from Linux 4.4. +var AMD64 = &kernel.SyscallTable{ + OS: abi.Linux, + Arch: arch.AMD64, + Version: kernel.Version{ + // Version 4.4 is chosen as a stable, longterm version of Linux, which + // guides the interface provided by this syscall table. The build + // version is that for a clean build with default kernel config, at 5 + // minutes after v4.4 was tagged. + Sysname: _LINUX_SYSNAME, + Release: _LINUX_RELEASE, + Version: _LINUX_VERSION, + }, + AuditNumber: linux.AUDIT_ARCH_X86_64, + Table: map[uintptr]kernel.Syscall{ + 0: syscalls.Supported("read", Read), + 1: syscalls.Supported("write", Write), + 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil), + 3: syscalls.Supported("close", Close), + 4: syscalls.Supported("stat", Stat), + 5: syscalls.Supported("fstat", Fstat), + 6: syscalls.Supported("lstat", Lstat), + 7: syscalls.Supported("poll", Poll), + 8: syscalls.Supported("lseek", Lseek), + 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil), + 10: syscalls.Supported("mprotect", Mprotect), + 11: syscalls.Supported("munmap", Munmap), + 12: syscalls.Supported("brk", Brk), + 13: syscalls.Supported("rt_sigaction", RtSigaction), + 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask), + 15: syscalls.Supported("rt_sigreturn", RtSigreturn), + 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), + 17: syscalls.Supported("pread64", Pread64), + 18: syscalls.Supported("pwrite64", Pwrite64), + 19: syscalls.Supported("readv", Readv), + 20: syscalls.Supported("writev", Writev), + 21: syscalls.Supported("access", Access), + 22: syscalls.Supported("pipe", Pipe), + 23: syscalls.Supported("select", Select), + 24: syscalls.Supported("sched_yield", SchedYield), + 25: syscalls.Supported("mremap", Mremap), + 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), + 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), + 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), + 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), + 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), + 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), + 32: syscalls.Supported("dup", Dup), + 33: syscalls.Supported("dup2", Dup2), + 34: syscalls.Supported("pause", Pause), + 35: syscalls.Supported("nanosleep", Nanosleep), + 36: syscalls.Supported("getitimer", Getitimer), + 37: syscalls.Supported("alarm", Alarm), + 38: syscalls.Supported("setitimer", Setitimer), + 39: syscalls.Supported("getpid", Getpid), + 40: syscalls.Supported("sendfile", Sendfile), + 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), + 42: syscalls.Supported("connect", Connect), + 43: syscalls.Supported("accept", Accept), + 44: syscalls.Supported("sendto", SendTo), + 45: syscalls.Supported("recvfrom", RecvFrom), + 46: syscalls.Supported("sendmsg", SendMsg), + 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), + 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), + 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), + 50: syscalls.Supported("listen", Listen), + 51: syscalls.Supported("getsockname", GetSockName), + 52: syscalls.Supported("getpeername", GetPeerName), + 53: syscalls.Supported("socketpair", SocketPair), + 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), + 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), + 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), + 57: syscalls.Supported("fork", Fork), + 58: syscalls.Supported("vfork", Vfork), + 59: syscalls.Supported("execve", Execve), + 60: syscalls.Supported("exit", Exit), + 61: syscalls.Supported("wait4", Wait4), + 62: syscalls.Supported("kill", Kill), + 63: syscalls.Supported("uname", Uname), + 64: syscalls.Supported("semget", Semget), + 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 67: syscalls.Supported("shmdt", Shmdt), + 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), + 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), + 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), + 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), + 76: syscalls.Supported("truncate", Truncate), + 77: syscalls.Supported("ftruncate", Ftruncate), + 78: syscalls.Supported("getdents", Getdents), + 79: syscalls.Supported("getcwd", Getcwd), + 80: syscalls.Supported("chdir", Chdir), + 81: syscalls.Supported("fchdir", Fchdir), + 82: syscalls.Supported("rename", Rename), + 83: syscalls.Supported("mkdir", Mkdir), + 84: syscalls.Supported("rmdir", Rmdir), + 85: syscalls.Supported("creat", Creat), + 86: syscalls.Supported("link", Link), + 87: syscalls.Supported("unlink", Unlink), + 88: syscalls.Supported("symlink", Symlink), + 89: syscalls.Supported("readlink", Readlink), + 90: syscalls.Supported("chmod", Chmod), + 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), + 92: syscalls.Supported("chown", Chown), + 93: syscalls.Supported("fchown", Fchown), + 94: syscalls.Supported("lchown", Lchown), + 95: syscalls.Supported("umask", Umask), + 96: syscalls.Supported("gettimeofday", Gettimeofday), + 97: syscalls.Supported("getrlimit", Getrlimit), + 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), + 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), + 100: syscalls.Supported("times", Times), + 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), + 102: syscalls.Supported("getuid", Getuid), + 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), + 104: syscalls.Supported("getgid", Getgid), + 105: syscalls.Supported("setuid", Setuid), + 106: syscalls.Supported("setgid", Setgid), + 107: syscalls.Supported("geteuid", Geteuid), + 108: syscalls.Supported("getegid", Getegid), + 109: syscalls.Supported("setpgid", Setpgid), + 110: syscalls.Supported("getppid", Getppid), + 111: syscalls.Supported("getpgrp", Getpgrp), + 112: syscalls.Supported("setsid", Setsid), + 113: syscalls.Supported("setreuid", Setreuid), + 114: syscalls.Supported("setregid", Setregid), + 115: syscalls.Supported("getgroups", Getgroups), + 116: syscalls.Supported("setgroups", Setgroups), + 117: syscalls.Supported("setresuid", Setresuid), + 118: syscalls.Supported("getresuid", Getresuid), + 119: syscalls.Supported("setresgid", Setresgid), + 120: syscalls.Supported("getresgid", Getresgid), + 121: syscalls.Supported("getpgid", Getpgid), + 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 124: syscalls.Supported("getsid", Getsid), + 125: syscalls.Supported("capget", Capget), + 126: syscalls.Supported("capset", Capset), + 127: syscalls.Supported("rt_sigpending", RtSigpending), + 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), + 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), + 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend), + 131: syscalls.Supported("sigaltstack", Sigaltstack), + 132: syscalls.Supported("utime", Utime), + 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil), + 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil), + 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), + 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil), + 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), + 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), + 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}), + 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), + 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), + 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), + 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), + 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), + 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), + 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), + 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), + 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), + 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), + 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil), + 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil), + 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil), + 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), + 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil), + 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), + 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), + 161: syscalls.Supported("chroot", Chroot), + 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), + 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), + 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), + 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), + 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), + 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), + 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), + 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), + 170: syscalls.Supported("sethostname", Sethostname), + 171: syscalls.Supported("setdomainname", Setdomainname), + 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil), + 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil), + 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil), + 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), + 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), + 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), + 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), + 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations + 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), + 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), + 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), + 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil), + 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), + 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), + 186: syscalls.Supported("gettid", Gettid), + 187: syscalls.Supported("readahead", Readahead), + 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 200: syscalls.Supported("tkill", Tkill), + 201: syscalls.Supported("time", Time), + 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), + 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), + 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), + 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), + 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), + 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), + 213: syscalls.Supported("epoll_create", EpollCreate), + 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil), + 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil), + 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), + 217: syscalls.Supported("getdents64", Getdents64), + 218: syscalls.Supported("set_tid_address", SetTidAddress), + 219: syscalls.Supported("restart_syscall", RestartSyscall), + 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), + 222: syscalls.Supported("timer_create", TimerCreate), + 223: syscalls.Supported("timer_settime", TimerSettime), + 224: syscalls.Supported("timer_gettime", TimerGettime), + 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun), + 226: syscalls.Supported("timer_delete", TimerDelete), + 227: syscalls.Supported("clock_settime", ClockSettime), + 228: syscalls.Supported("clock_gettime", ClockGettime), + 229: syscalls.Supported("clock_getres", ClockGetres), + 230: syscalls.Supported("clock_nanosleep", ClockNanosleep), + 231: syscalls.Supported("exit_group", ExitGroup), + 232: syscalls.Supported("epoll_wait", EpollWait), + 233: syscalls.Supported("epoll_ctl", EpollCtl), + 234: syscalls.Supported("tgkill", Tgkill), + 235: syscalls.Supported("utimes", Utimes), + 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil), + 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), + 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), + 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), + 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), + 247: syscalls.Supported("waitid", Waitid), + 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), + 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), + 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), + 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil), + 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), + 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), + 257: syscalls.Supported("openat", Openat), + 258: syscalls.Supported("mkdirat", Mkdirat), + 259: syscalls.Supported("mknodat", Mknodat), + 260: syscalls.Supported("fchownat", Fchownat), + 261: syscalls.Supported("futimesat", Futimesat), + 262: syscalls.Supported("fstatat", Fstatat), + 263: syscalls.Supported("unlinkat", Unlinkat), + 264: syscalls.Supported("renameat", Renameat), + 265: syscalls.Supported("linkat", Linkat), + 266: syscalls.Supported("symlinkat", Symlinkat), + 267: syscalls.Supported("readlinkat", Readlinkat), + 268: syscalls.Supported("fchmodat", Fchmodat), + 269: syscalls.Supported("faccessat", Faccessat), + 270: syscalls.Supported("pselect", Pselect), + 271: syscalls.Supported("ppoll", Ppoll), + 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), + 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), + 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), + 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) + 280: syscalls.Supported("utimensat", Utimensat), + 281: syscalls.Supported("epoll_pwait", EpollPwait), + 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), + 283: syscalls.Supported("timerfd_create", TimerfdCreate), + 284: syscalls.Supported("eventfd", Eventfd), + 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), + 286: syscalls.Supported("timerfd_settime", TimerfdSettime), + 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), + 288: syscalls.Supported("accept4", Accept4), + 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), + 290: syscalls.Supported("eventfd2", Eventfd2), + 291: syscalls.Supported("epoll_create1", EpollCreate1), + 292: syscalls.Supported("dup3", Dup3), + 293: syscalls.Supported("pipe2", Pipe2), + 294: syscalls.Supported("inotify_init1", InotifyInit1), + 295: syscalls.Supported("preadv", Preadv), + 296: syscalls.Supported("pwritev", Pwritev), + 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), + 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), + 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), + 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 302: syscalls.Supported("prlimit64", Prlimit64), + 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), + 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), + 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), + 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) + 309: syscalls.Supported("getcpu", Getcpu), + 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), + 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), + 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) + 317: syscalls.Supported("seccomp", Seccomp), + 318: syscalls.Supported("getrandom", GetRandom), + 319: syscalls.Supported("memfd_create", MemfdCreate), + 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), + 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), + 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) + 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) + 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) + 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + + // Syscalls after 325 are "backports" from versions of Linux after 4.4. + 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), + 327: syscalls.Supported("preadv2", Preadv2), + 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 332: syscalls.Supported("statx", Statx), + }, + + Emulate: map[usermem.Addr]uintptr{ + 0xffffffffff600000: 96, // vsyscall gettimeofday(2) + 0xffffffffff600400: 201, // vsyscall time(2) + 0xffffffffff600800: 309, // vsyscall getcpu(2) + }, + Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { + t.Kernel().EmitUnimplementedEvent(t) + return 0, syserror.ENOSYS + }, +} diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go new file mode 100644 index 000000000..1d3b63020 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -0,0 +1,313 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// ARM64 is a table of Linux arm64 syscall API with the corresponding syscall +// numbers from Linux 4.4. +var ARM64 = &kernel.SyscallTable{ + OS: abi.Linux, + Arch: arch.ARM64, + Version: kernel.Version{ + Sysname: _LINUX_SYSNAME, + Release: _LINUX_RELEASE, + Version: _LINUX_VERSION, + }, + AuditNumber: linux.AUDIT_ARCH_AARCH64, + Table: map[uintptr]kernel.Syscall{ + 0: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 1: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 5: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 8: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 12: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 13: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 14: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 15: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 16: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 17: syscalls.Supported("getcwd", Getcwd), + 18: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), + 19: syscalls.Supported("eventfd2", Eventfd2), + 20: syscalls.Supported("epoll_create1", EpollCreate1), + 21: syscalls.Supported("epoll_ctl", EpollCtl), + 22: syscalls.Supported("epoll_pwait", EpollPwait), + 23: syscalls.Supported("dup", Dup), + 24: syscalls.Supported("dup3", Dup3), + 26: syscalls.Supported("inotify_init1", InotifyInit1), + 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), + 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), + 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 32: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), + 33: syscalls.Supported("mknodat", Mknodat), + 34: syscalls.Supported("mkdirat", Mkdirat), + 35: syscalls.Supported("unlinkat", Unlinkat), + 36: syscalls.Supported("symlinkat", Symlinkat), + 37: syscalls.Supported("linkat", Linkat), + 38: syscalls.Supported("renameat", Renameat), + 39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), + 40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), + 41: syscalls.Error("pivot_root", syserror.EPERM, "", nil), + 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), + 44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), + 46: syscalls.Supported("ftruncate", Ftruncate), + 47: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), + 48: syscalls.Supported("faccessat", Faccessat), + 49: syscalls.Supported("chdir", Chdir), + 50: syscalls.Supported("fchdir", Fchdir), + 51: syscalls.Supported("chroot", Chroot), + 52: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), + 53: syscalls.Supported("fchmodat", Fchmodat), + 54: syscalls.Supported("fchownat", Fchownat), + 55: syscalls.Supported("fchown", Fchown), + 56: syscalls.Supported("openat", Openat), + 57: syscalls.Supported("close", Close), + 58: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), + 59: syscalls.Supported("pipe2", Pipe2), + 60: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations + 61: syscalls.Supported("getdents64", Getdents64), + 62: syscalls.Supported("lseek", Lseek), + 63: syscalls.Supported("read", Read), + 64: syscalls.Supported("write", Write), + 65: syscalls.Supported("readv", Readv), + 66: syscalls.Supported("writev", Writev), + 67: syscalls.Supported("pread64", Pread64), + 68: syscalls.Supported("pwrite64", Pwrite64), + 69: syscalls.Supported("preadv", Preadv), + 70: syscalls.Supported("pwritev", Pwritev), + 71: syscalls.Supported("sendfile", Sendfile), + 72: syscalls.Supported("pselect", Pselect), + 73: syscalls.Supported("ppoll", Ppoll), + 74: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 77: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 78: syscalls.Supported("readlinkat", Readlinkat), + 80: syscalls.Supported("fstat", Fstat), + 81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), + 82: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), + 83: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), + 84: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), + 85: syscalls.Supported("timerfd_create", TimerfdCreate), + 86: syscalls.Supported("timerfd_settime", TimerfdSettime), + 87: syscalls.Supported("timerfd_gettime", TimerfdGettime), + 88: syscalls.Supported("utimensat", Utimensat), + 89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), + 90: syscalls.Supported("capget", Capget), + 91: syscalls.Supported("capset", Capset), + 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), + 93: syscalls.Supported("exit", Exit), + 94: syscalls.Supported("exit_group", ExitGroup), + 95: syscalls.Supported("waitid", Waitid), + 96: syscalls.Supported("set_tid_address", SetTidAddress), + 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), + 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), + 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 101: syscalls.Supported("nanosleep", Nanosleep), + 102: syscalls.Supported("getitimer", Getitimer), + 103: syscalls.Supported("setitimer", Setitimer), + 104: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), + 105: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), + 106: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), + 107: syscalls.Supported("timer_create", TimerCreate), + 108: syscalls.Supported("timer_gettime", TimerGettime), + 109: syscalls.Supported("timer_getoverrun", TimerGetoverrun), + 110: syscalls.Supported("timer_settime", TimerSettime), + 111: syscalls.Supported("timer_delete", TimerDelete), + 112: syscalls.Supported("clock_settime", ClockSettime), + 113: syscalls.Supported("clock_gettime", ClockGettime), + 114: syscalls.Supported("clock_getres", ClockGetres), + 115: syscalls.Supported("clock_nanosleep", ClockNanosleep), + 116: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), + 117: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), + 118: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), + 119: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), + 120: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), + 121: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), + 122: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), + 123: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), + 124: syscalls.Supported("sched_yield", SchedYield), + 125: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), + 126: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), + 127: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), + 128: syscalls.Supported("restart_syscall", RestartSyscall), + 129: syscalls.Supported("kill", Kill), + 130: syscalls.Supported("tkill", Tkill), + 131: syscalls.Supported("tgkill", Tgkill), + 132: syscalls.Supported("sigaltstack", Sigaltstack), + 133: syscalls.Supported("rt_sigsuspend", RtSigsuspend), + 134: syscalls.Supported("rt_sigaction", RtSigaction), + 135: syscalls.Supported("rt_sigprocmask", RtSigprocmask), + 136: syscalls.Supported("rt_sigpending", RtSigpending), + 137: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), + 138: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), + 139: syscalls.Supported("rt_sigreturn", RtSigreturn), + 140: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), + 141: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), + 142: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), + 143: syscalls.Supported("setregid", Setregid), + 144: syscalls.Supported("setgid", Setgid), + 145: syscalls.Supported("setreuid", Setreuid), + 146: syscalls.Supported("setuid", Setuid), + 147: syscalls.Supported("setresuid", Setresuid), + 148: syscalls.Supported("getresuid", Getresuid), + 149: syscalls.Supported("setresgid", Setresgid), + 150: syscalls.Supported("getresgid", Getresgid), + 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 153: syscalls.Supported("times", Times), + 154: syscalls.Supported("setpgid", Setpgid), + 155: syscalls.Supported("getpgid", Getpgid), + 156: syscalls.Supported("getsid", Getsid), + 157: syscalls.Supported("setsid", Setsid), + 158: syscalls.Supported("getgroups", Getgroups), + 159: syscalls.Supported("setgroups", Setgroups), + 160: syscalls.Supported("uname", Uname), + 161: syscalls.Supported("sethostname", Sethostname), + 162: syscalls.Supported("setdomainname", Setdomainname), + 163: syscalls.Supported("getrlimit", Getrlimit), + 164: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), + 165: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), + 166: syscalls.Supported("umask", Umask), + 167: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), + 168: syscalls.Supported("getcpu", Getcpu), + 169: syscalls.Supported("gettimeofday", Gettimeofday), + 170: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), + 171: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), + 172: syscalls.Supported("getpid", Getpid), + 173: syscalls.Supported("getppid", Getppid), + 174: syscalls.Supported("getuid", Getuid), + 175: syscalls.Supported("geteuid", Geteuid), + 176: syscalls.Supported("getgid", Getgid), + 177: syscalls.Supported("getegid", Getegid), + 178: syscalls.Supported("gettid", Gettid), + 179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), + 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 186: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 187: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 190: syscalls.Supported("semget", Semget), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), + 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), + 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), + 196: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), + 197: syscalls.Supported("shmdt", Shmdt), + 198: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), + 199: syscalls.Supported("socketpair", SocketPair), + 200: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), + 201: syscalls.Supported("listen", Listen), + 202: syscalls.Supported("accept", Accept), + 203: syscalls.Supported("connect", Connect), + 204: syscalls.Supported("getsockname", GetSockName), + 205: syscalls.Supported("getpeername", GetPeerName), + 206: syscalls.Supported("sendto", SendTo), + 207: syscalls.Supported("recvfrom", RecvFrom), + 208: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), + 209: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), + 210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), + 211: syscalls.Supported("sendmsg", SendMsg), + 212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), + 213: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341) + 214: syscalls.Supported("brk", Brk), + 215: syscalls.Supported("munmap", Munmap), + 216: syscalls.Supported("mremap", Mremap), + 217: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), + 218: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), + 219: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), + 220: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), + 221: syscalls.Supported("execve", Execve), + 224: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), + 225: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), + 226: syscalls.Supported("mprotect", Mprotect), + 227: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), + 228: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 229: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 230: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), + 233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), + 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), + 235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), + 236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), + 237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), + 238: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), + 239: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) + 240: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), + 241: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), + 242: syscalls.Supported("accept4", Accept4), + 243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), + 260: syscalls.Supported("wait4", Wait4), + 261: syscalls.Supported("prlimit64", Prlimit64), + 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 264: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 265: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), + 267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), + 268: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) + 269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), + 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), + 273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), + 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) + 277: syscalls.Supported("seccomp", Seccomp), + 278: syscalls.Supported("getrandom", GetRandom), + 279: syscalls.Supported("memfd_create", MemfdCreate), + 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), + 281: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) + 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) + 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) + 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), + 286: syscalls.Supported("preadv2", Preadv2), + 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 291: syscalls.Supported("statx", Statx), + }, + Emulate: map[usermem.Addr]uintptr{}, + + Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { + t.Kernel().EmitUnimplementedEvent(t) + return 0, syserror.ENOSYS + }, +} diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 2e00a91ce..b9a8e3e21 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -1423,9 +1423,6 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { if err != nil { return err } - if dirPath { - return syserror.ENOENT - } return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error { if !fs.IsDir(d.Inode.StableAttr) { @@ -1436,7 +1433,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { return err } - return d.Remove(t, root, name) + return d.Remove(t, root, name, dirPath) }) } diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 3ab54271c..cd31e0649 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -72,6 +72,39 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file) } +// Readahead implements readahead(2). +func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + offset := args[1].Int64() + size := args[2].SizeT() + + file := t.GetFile(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the file is readable. + if !file.Flags().Read { + return 0, nil, syserror.EBADF + } + + // Check that the size is valid. + if int(size) < 0 { + return 0, nil, syserror.EINVAL + } + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Return EINVAL; if the underlying file type does not support readahead, + // then Linux will return EINVAL to indicate as much. In the future, we + // may extend this function to actually support readahead hints. + return 0, nil, syserror.EINVAL +} + // Pread64 implements linux syscall pread64(2). func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 0104a94c0..fb6efd5d8 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -20,7 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd" + "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -506,3 +509,77 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?") return 0, nil, syserror.EINTR } + +// sharedSignalfd is shared between the two calls. +func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) { + // Copy in the signal mask. + mask, err := copyInSigSet(t, sigset, sigsetsize) + if err != nil { + return 0, nil, err + } + + // Always check for valid flags, even if not creating. + if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 { + return 0, nil, syserror.EINVAL + } + + // Is this a change to an existing signalfd? + // + // The spec indicates that this should adjust the mask. + if fd != -1 { + file := t.GetFile(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Is this a signalfd? + if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok { + s.SetMask(mask) + return 0, nil, nil + } + + // Not a signalfd. + return 0, nil, syserror.EINVAL + } + + // Create a new file. + file, err := signalfd.New(t, mask) + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + // Set appropriate flags. + file.SetFlags(fs.SettableFileFlags{ + NonBlocking: flags&linux.SFD_NONBLOCK != 0, + }) + + // Create a new descriptor. + fd, err = t.NewFDFrom(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.SFD_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + + // Done. + return uintptr(fd), nil, nil +} + +// Signalfd implements the linux syscall signalfd(2). +func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + sigset := args[1].Pointer() + sigsetsize := args[2].SizeT() + return sharedSignalfd(t, fd, sigset, sigsetsize, 0) +} + +// Signalfd4 implements the linux syscall signalfd4(2). +func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + sigset := args[1].Pointer() + sigsetsize := args[2].SizeT() + flags := args[3].Int() + return sharedSignalfd(t, fd, sigset, sigsetsize, flags) +} diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 3bac4d90d..b5a72ce63 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - if optLen <= 0 { + if optLen < 0 { return 0, nil, syserror.EINVAL } if optLen > maxOptLen { diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 17e3dde1f..dd3a5807f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB total int64 n int64 err error - ch chan struct{} - inW bool - outW bool + inCh chan struct{} + outCh chan struct{} ) for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) @@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB break } - // Are we a registered waiter? - if ch == nil { - ch = make(chan struct{}, 1) - } - if !inW && !inFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - inFile.EventRegister(&w, EventMaskRead) - defer inFile.EventUnregister(&w) - inW = true // Registered. - } else if !outW && !outFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - outFile.EventRegister(&w, EventMaskWrite) - defer outFile.EventUnregister(&w) - outW = true // Registered. - } - - // Was anything registered? If no, everything is non-blocking. - if !inW && !outW { - break - } - - if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) { - // Something became ready, try again without blocking. - continue + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(EventMaskRead) == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, EventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } } - - // Block until there's data. - if err = t.Block(ch); err != nil { - break + if outFile.Readiness(EventMaskWrite) == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } } } @@ -91,22 +88,29 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } // Get files. + inFile := t.GetFile(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + + if !inFile.Flags().Read { + return 0, nil, syserror.EBADF + } + outFile := t.GetFile(outFD) if outFile == nil { return 0, nil, syserror.EBADF } defer outFile.DecRef() - inFile := t.GetFile(inFD) - if inFile == nil { + if !outFile.Flags().Write { return 0, nil, syserror.EBADF } - defer inFile.DecRef() - // Verify that the outfile Append flag is not set. Note that fs.Splice - // itself validates that the output file is writable. + // Verify that the outfile Append flag is not set. if outFile.Flags().Append { - return 0, nil, syserror.EBADF + return 0, nil, syserror.EINVAL } // Verify that we have a regular infile. This is a requirement; the @@ -142,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Length: count, SrcOffset: true, SrcStart: offset, - }, false) + }, outFile.Flags().NonBlocking) // Copy out the new offset. if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { @@ -152,12 +156,17 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Send data using splice. n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, - }, false) + }, outFile.Flags().NonBlocking) + } + + // Sendfile can't lose any data because inFD is always a regual file. + if n != 0 { + err = nil } // We can only pass a single file to handleIOError, so pick inFile // arbitrarily. This is used only for debugging purposes. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "sendfile", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile) } // Splice implements splice(2). @@ -174,12 +183,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. Note that unlike in Linux, this - // flag is applied consistently. We will have either fully blocking or - // non-blocking behavior below, regardless of the underlying files - // being spliced to. It's unclear if this is a bug or not yet. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -193,6 +196,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } defer inFile.DecRef() + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Construct our options. // // Note that exactly one of the underlying buffers must be a pipe. We @@ -240,17 +250,17 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if inOffset != 0 || outOffset != 0 { return 0, nil, syserror.ESPIPE } - default: - return 0, nil, syserror.EINVAL - } - // We may not refer to the same pipe; otherwise it's a continuous loop. - if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID { + // We may not refer to the same pipe; otherwise it's a continuous loop. + if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID { + return 0, nil, syserror.EINVAL + } + default: return 0, nil, syserror.EINVAL } // Splice data. - n, err := doSplice(t, outFile, inFile, opts, nonBlocking) + n, err := doSplice(t, outFile, inFile, opts, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile) @@ -268,9 +278,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -294,12 +301,20 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } + // The operation is non-blocking if anything is non-blocking. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Splice data. n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, Dup: true, - }, nonBlocking) + }, nonBlock) + + // Tee doesn't change a state of inFD, so it can't lose any data. + if n != 0 { + err = nil + } // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile) } diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index 4b3f043a2..b887fa9d7 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -15,6 +15,7 @@ package linux import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -228,41 +229,35 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use timer.Destroy() - var remaining time.Duration - // Did we just block for the entire duration? - if err == syserror.ETIMEDOUT { - remaining = 0 - } else { - remaining = dur - after.Sub(start) + switch err { + case syserror.ETIMEDOUT: + // Slept for entire timeout. + return nil + case syserror.ErrInterrupted: + // Interrupted. + remaining := dur - after.Sub(start) if remaining < 0 { remaining = time.Duration(0) } - } - // Copy out remaining time. - if err != nil && rem != usermem.Addr(0) { - timeleft := linux.NsecToTimespec(remaining.Nanoseconds()) - if err := copyTimespecOut(t, rem, &timeleft); err != nil { - return err + // Copy out remaining time. + if rem != 0 { + timeleft := linux.NsecToTimespec(remaining.Nanoseconds()) + if err := copyTimespecOut(t, rem, &timeleft); err != nil { + return err + } } - } - - // Did we just block for the entire duration? - if err == syserror.ETIMEDOUT { - return nil - } - // If interrupted, arrange for a restart with the remaining duration. - if err == syserror.ErrInterrupted { + // Arrange for a restart with the remaining duration. t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{ c: c, duration: remaining, rem: rem, }) return kernel.ERESTART_RESTARTBLOCK + default: + panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err)) } - - return err } // Nanosleep implements linux syscall Nanosleep(2). diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go index 271ace08e..748e8dd8d 100644 --- a/pkg/sentry/syscalls/linux/sys_utsname.go +++ b/pkg/sentry/syscalls/linux/sys_utsname.go @@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, syserror.EINVAL } - name, err := t.CopyInString(nameAddr, int(size)) - if err != nil { + name := make([]byte, size) + if _, err := t.CopyInBytes(nameAddr, name); err != nil { return 0, nil, err } - utsns.SetHostName(name) + utsns.SetHostName(string(name)) return 0, nil, nil } diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 8aa6a3017..beb43ba13 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD index b69603da3..fc7614fff 100644 --- a/pkg/sentry/unimpl/BUILD +++ b/pkg/sentry/unimpl/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -10,6 +11,12 @@ proto_library( deps = ["//pkg/sentry/arch:registers_proto"], ) +cc_proto_library( + name = "unimplemented_syscall_cc_proto", + visibility = ["//visibility:public"], + deps = [":unimplemented_syscall_proto"], +) + go_proto_library( name = "unimplemented_syscall_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto", diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index f4326706a..d6ef644d8 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -277,8 +277,3 @@ func TotalMemory(memSize, used uint64) uint64 { } return memSize } - -// IncrementalMappedAccounting controls whether host mapped memory is accounted -// incrementally during map translation. This may be modified during early -// initialization, and is read-only afterward. -var IncrementalMappedAccounting = false diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index a5b4206bb..cc5d25762 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "addr_range", diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go index 6eced660a..7b1f312b1 100644 --- a/pkg/sentry/usermem/usermem.go +++ b/pkg/sentry/usermem/usermem.go @@ -16,6 +16,7 @@ package usermem import ( + "bytes" "errors" "io" "strconv" @@ -270,11 +271,10 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts) // Look for the terminating zero byte, which may have occurred before // hitting err. - for i, c := range buf[done : done+n] { - if c == 0 { - return stringFromImmutableBytes(buf[:done+i]), nil - } + if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 { + return stringFromImmutableBytes(buf[:done+i]), nil } + done += n if err != nil { return stringFromImmutableBytes(buf[:done]), err diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 0f247bf77..eff4b44f6 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 86bde7fb3..3a9665800 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -102,7 +102,7 @@ type FileDescriptionImpl interface { // OnClose is called when a file descriptor representing the // FileDescription is closed. Note that returning a non-nil error does not // prevent the file descriptor from being closed. - OnClose() error + OnClose(ctx context.Context) error // StatusFlags returns file description status flags, as for // fcntl(F_GETFL). @@ -180,7 +180,7 @@ type FileDescriptionImpl interface { // ConfigureMMap mutates opts to implement mmap(2) for the file. Most // implementations that support memory mapping can call // GenericConfigureMMap with the appropriate memmap.Mappable. - ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error + ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) @@ -199,8 +199,11 @@ type Dirent struct { // Ino is the inode number. Ino uint64 - // Off is this Dirent's offset. - Off int64 + // NextOff is the offset of the *next* Dirent in the directory; that is, + // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will + // cause the next call to FileDescription.IterDirents() to yield the next + // Dirent. (The offset of the first Dirent in a directory is always 0.) + NextOff int64 } // IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents. diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index ba230da72..4fbad7840 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -45,7 +45,7 @@ type FileDescriptionDefaultImpl struct{} // OnClose implements FileDescriptionImpl.OnClose analogously to // file_operations::flush == NULL in Linux. -func (FileDescriptionDefaultImpl) OnClose() error { +func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error { return nil } @@ -117,7 +117,7 @@ func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error { // ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to // file_operations::mmap == NULL in Linux. -func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return syserror.ENODEV } diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 187e5410c..3aa73d911 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -31,14 +31,14 @@ type GetDentryOptions struct { // FilesystemImpl.MkdirAt(). type MkdirOptions struct { // Mode is the file mode bits for the created directory. - Mode uint16 + Mode linux.FileMode } // MknodOptions contains options to VirtualFilesystem.MknodAt() and // FilesystemImpl.MknodAt(). type MknodOptions struct { // Mode is the file type and mode bits for the created file. - Mode uint16 + Mode linux.FileMode // If Mode specifies a character or block device special file, DevMajor and // DevMinor are the major and minor device numbers for the created device. @@ -61,7 +61,7 @@ type OpenOptions struct { // If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the // created file. - Mode uint16 + Mode linux.FileMode } // ReadOptions contains options to FileDescription.PRead(), diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index 00665c939..a23c86fb1 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,6 +7,7 @@ go_library( name = "sleep", srcs = [ "commit_amd64.s", + "commit_arm64.s", "commit_asm.go", "commit_noasm.go", "sleep_unsafe.go", diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s new file mode 100644 index 000000000..d0ef15b20 --- /dev/null +++ b/pkg/sleep/commit_arm64.s @@ -0,0 +1,38 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +#define preparingG 1 + +// See commit_noasm.go for a description of commitSleep. +// +// func commitSleep(g uintptr, waitingG *uintptr) bool +TEXT ·commitSleep(SB),NOSPLIT,$0-24 + MOVD waitingG+8(FP), R0 + MOVD $preparingG, R1 + MOVD G+0(FP), R2 + + // Store the G in waitingG if it's still preparingG. If it's anything + // else it means a waker has aborted the sleep. +again: + LDAXR (R0), R3 + CMP R1, R3 + BNE ok + STLXR R2, (R0), R3 + CBNZ R3, again +ok: + CSET EQ, R0 + MOVB R0, ret+16(FP) + RET diff --git a/pkg/sleep/commit_asm.go b/pkg/sleep/commit_asm.go index 35e2cc337..75728a97d 100644 --- a/pkg/sleep/commit_asm.go +++ b/pkg/sleep/commit_asm.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 package sleep diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go index 686b1da3d..3af447fb9 100644 --- a/pkg/sleep/commit_noasm.go +++ b/pkg/sleep/commit_noasm.go @@ -13,7 +13,7 @@ // limitations under the License. // +build !race -// +build !amd64 +// +build !amd64,!arm64 package sleep diff --git a/pkg/state/BUILD b/pkg/state/BUILD index c0f3c658d..329904457 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index e70f4a79f..8a865d229 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index b149f9e02..bd3f9fd28 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index df37c7d5a..3c2b2b5ea 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "tcpip", diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 0d2637ee4..78df5a0b1 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 672f026b2..8ced960bb 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) { func newLoopbackStack() (*stack.Stack, *tcpip.Error) { // Create the stack and add a NIC. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, + }) if err := s.CreateNIC(NICID, loopback.New()); err != nil { return nil, err diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 3301967fb..d6c31bfa2 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "buffer", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index afcabd51d..096ad71ab 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -586,3 +586,103 @@ func Payload(want []byte) TransportChecker { } } } + +// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and +// potentially additional ICMPv4 header fields. +func ICMPv4(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) + } + + icmp := header.ICMPv4(last.Payload()) + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// ICMPv4Type creates a checker that checks the ICMPv4 Type field. +func ICMPv4Type(want header.ICMPv4Type) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + } + if got := icmpv4.Type(); got != want { + t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + } + } +} + +// ICMPv4Code creates a checker that checks the ICMPv4 Code field. +func ICMPv4Code(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + } + if got := icmpv4.Code(); got != want { + t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + } + } +} + +// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and +// potentially additional ICMPv6 header fields. +func ICMPv6(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) + } + + icmp := header.ICMPv6(last.Payload()) + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// ICMPv6Type creates a checker that checks the ICMPv6 Type field. +func ICMPv6Type(want header.ICMPv6Type) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + } + if got := icmpv6.Type(); got != want { + t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + } + } +} + +// ICMPv6Code creates a checker that checks the ICMPv6 Code field. +func ICMPv6Code(want byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + } + if got := icmpv6.Code(); got != want { + t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + } + } +} diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 29b30be9c..0c5c20cea 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 76ef02f13..a255231a3 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "header", diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c52c0d851..0cac6c0a5 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. @@ -25,13 +26,29 @@ type ICMPv4 []byte const ( // ICMPv4PayloadOffset defines the start of ICMP payload. - ICMPv4PayloadOffset = 4 + ICMPv4PayloadOffset = 8 // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. ICMPv4MinimumSize = 8 // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 + + // icmpv4ChecksumOffset is the offset of the checksum field + // in an ICMPv4 message. + icmpv4ChecksumOffset = 2 + + // icmpv4MTUOffset is the offset of the MTU field + // in a ICMPv4FragmentationNeeded message. + icmpv4MTUOffset = 6 + + // icmpv4IdentOffset is the offset of the ident field + // in a ICMPv4EchoRequest/Reply message. + icmpv4IdentOffset = 4 + + // icmpv4SequenceOffset is the offset of the sequence field + // in a ICMPv4EchoRequest/Reply message. + icmpv4SequenceOffset = 6 ) // ICMPv4Type is the ICMP type field described in RFC 792. @@ -72,12 +89,12 @@ func (b ICMPv4) SetCode(c byte) { b[1] = c } // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { - return binary.BigEndian.Uint16(b[2:]) + return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) } // SetChecksum sets the ICMP checksum field. func (b ICMPv4) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[2:], checksum) + binary.BigEndian.PutUint16(b[icmpv4ChecksumOffset:], checksum) } // SourcePort implements Transport.SourcePort. @@ -102,3 +119,51 @@ func (ICMPv4) SetDestinationPort(uint16) { func (b ICMPv4) Payload() []byte { return b[ICMPv4PayloadOffset:] } + +// MTU retrieves the MTU field from an ICMPv4 message. +func (b ICMPv4) MTU() uint16 { + return binary.BigEndian.Uint16(b[icmpv4MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv4 message. +func (b ICMPv4) SetMTU(mtu uint16) { + binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv4 message. +func (b ICMPv4) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv4IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv4 message. +func (b ICMPv4) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident) +} + +// Sequence retrieves the Sequence field from an ICMPv4 message. +func (b ICMPv4) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv4 message. +func (b ICMPv4) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence) +} + +// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, +// and payload. +func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := uint16(0) + for _, v := range vv.Views() { + xsum = Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 3cc57e234..1125a7d14 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. @@ -25,14 +26,18 @@ type ICMPv6 []byte const ( // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. - ICMPv6MinimumSize = 4 + ICMPv6MinimumSize = 8 + + // ICMPv6PayloadOffset is the offset of the payload in an + // ICMP packet. + ICMPv6PayloadOffset = 8 // ICMPv6ProtocolNumber is the ICMP transport protocol number. ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58 // ICMPv6NeighborSolicitMinimumSize is the minimum size of a // neighbor solicitation packet. - ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16 + ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16 // ICMPv6NeighborAdvertSize is size of a neighbor advertisement. ICMPv6NeighborAdvertSize = 32 @@ -42,11 +47,27 @@ const ( // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP // destination unreachable packet. - ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4 + ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP // packet-too-big packet. - ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4 + ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + + // icmpv6ChecksumOffset is the offset of the checksum field + // in an ICMPv6 message. + icmpv6ChecksumOffset = 2 + + // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6 + // PacketTooBig message. + icmpv6MTUOffset = 4 + + // icmpv6IdentOffset is the offset of the ident field + // in a ICMPv6 Echo Request/Reply message. + icmpv6IdentOffset = 4 + + // icmpv6SequenceOffset is the offset of the sequence field + // in a ICMPv6 Echo Request/Reply message. + icmpv6SequenceOffset = 6 ) // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. @@ -89,12 +110,12 @@ func (b ICMPv6) SetCode(c byte) { b[1] = c } // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { - return binary.BigEndian.Uint16(b[2:]) + return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } // SetChecksum calculates and sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[2:], checksum) + binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } // SourcePort implements Transport.SourcePort. @@ -115,7 +136,60 @@ func (ICMPv6) SetSourcePort(uint16) { func (ICMPv6) SetDestinationPort(uint16) { } +// MTU retrieves the MTU field from an ICMPv6 message. +func (b ICMPv6) MTU() uint32 { + return binary.BigEndian.Uint32(b[icmpv6MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv6 message. +func (b ICMPv6) SetMTU(mtu uint32) { + binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv6 message. +func (b ICMPv6) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv6IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv6 message. +func (b ICMPv6) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident) +} + +// Sequence retrieves the Sequence field from an ICMPv6 message. +func (b ICMPv6) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv6 message. +func (b ICMPv6) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) +} + // Payload implements Transport.Payload. func (b ICMPv6) Payload() []byte { - return b[ICMPv6MinimumSize:] + return b[ICMPv6PayloadOffset:] +} + +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// IPv6 src/dst addresses and the payload. +func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := Checksum([]byte(src), 0) + xsum = Checksum([]byte(dst), xsum) + var upperLayerLength [4]byte + binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) + xsum = Checksum(upperLayerLength[:], xsum) + xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum) + for _, v := range vv.Views() { + xsum = Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 17fc9c68e..e5360e7c1 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -21,16 +21,18 @@ import ( ) const ( - versIHL = 0 - tos = 1 - totalLen = 2 - id = 4 - flagsFO = 6 - ttl = 8 - protocol = 9 - checksum = 10 - srcAddr = 12 - dstAddr = 16 + versIHL = 0 + tos = 1 + // IPv4TotalLenOffset is the offset of the total length field in the + // IPv4 header. + IPv4TotalLenOffset = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + checksum = 10 + srcAddr = 12 + dstAddr = 16 ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the @@ -103,6 +105,11 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP + // packet that every IPv4 capable host must be able to + // process/reassemble. + IPv4MinimumProcessableDatagramSize = 576 ) // Flags that may be set in an IPv4 packet. @@ -163,7 +170,7 @@ func (b IPv4) FragmentOffset() uint16 { // TotalLength returns the "total length" field of the ipv4 header. func (b IPv4) TotalLength() uint16 { - return binary.BigEndian.Uint16(b[totalLen:]) + return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } // Checksum returns the checksum field of the ipv4 header. @@ -209,7 +216,7 @@ func (b IPv4) SetTOS(v uint8, _ uint32) { // SetTotalLength sets the "total length" field of the ipv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { - binary.BigEndian.PutUint16(b[totalLen:], totalLength) + binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } // SetChecksum sets the checksum field of the ipv4 header. @@ -265,7 +272,7 @@ func (b IPv4) Encode(i *IPv4Fields) { // packets are produced. func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) { b.SetTotalLength(totalLength) - checksum := Checksum(b[totalLen:totalLen+2], partialChecksum) + checksum := Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum) b.SetChecksum(^checksum) } @@ -277,7 +284,7 @@ func (b IPv4) IsValid(pktSize int) bool { hlen := int(b.HeaderLength()) tlen := int(b.TotalLength()) - if hlen > tlen || tlen > pktSize { + if hlen < IPv4MinimumSize || hlen > tlen || tlen > pktSize { return false } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 31be42ce0..9d3abc0e4 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -22,12 +22,14 @@ import ( ) const ( - versTCFL = 0 - payloadLen = 4 - nextHdr = 6 - hopLimit = 7 - v6SrcAddr = 8 - v6DstAddr = 24 + versTCFL = 0 + // IPv6PayloadLenOffset is the offset of the PayloadLength field in + // IPv6 header. + IPv6PayloadLenOffset = 4 + nextHdr = 6 + hopLimit = 7 + v6SrcAddr = 8 + v6DstAddr = v6SrcAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -74,6 +76,13 @@ const ( // IPv6Version is the version of the ipv6 protocol. IPv6Version = 6 + // IPv6AllNodesMulticastAddress is a link-local multicast group that + // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all nodes on a link. + // + // The address is ff02::1. + IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 @@ -94,7 +103,7 @@ var IPv6EmptySubnet = func() tcpip.Subnet { // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { - return binary.BigEndian.Uint16(b[payloadLen:]) + return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) } // HopLimit returns the value of the "hop limit" field of the ipv6 header. @@ -119,13 +128,13 @@ func (b IPv6) Payload() []byte { // SourceAddress returns the "source address" field of the ipv6 header. func (b IPv6) SourceAddress() tcpip.Address { - return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]) + return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize]) } // DestinationAddress returns the "destination address" field of the ipv6 // header. func (b IPv6) DestinationAddress() tcpip.Address { - return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize]) + return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize]) } // Checksum implements Network.Checksum. Given that IPv6 doesn't have a @@ -148,18 +157,18 @@ func (b IPv6) SetTOS(t uint8, l uint32) { // SetPayloadLength sets the "payload length" field of the ipv6 header. func (b IPv6) SetPayloadLength(payloadLength uint16) { - binary.BigEndian.PutUint16(b[payloadLen:], payloadLength) + binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength) } // SetSourceAddress sets the "source address" field of the ipv6 header. func (b IPv6) SetSourceAddress(addr tcpip.Address) { - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr) + copy(b[v6SrcAddr:][:IPv6AddressSize], addr) } // SetDestinationAddress sets the "destination address" field of the ipv6 // header. func (b IPv6) SetDestinationAddress(addr tcpip.Address) { - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr) + copy(b[v6DstAddr:][:IPv6AddressSize], addr) } // SetNextHeader sets the value of the "next header" field of the ipv6 header. @@ -178,8 +187,8 @@ func (b IPv6) Encode(i *IPv6Fields) { b.SetPayloadLength(i.PayloadLength) b[nextHdr] = i.NextHeader b[hopLimit] = i.HopLimit - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr) - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr) + b.SetSourceAddress(i.SrcAddr) + b.SetDestinationAddress(i.DstAddr) } // IsValid performs basic validation on the packet. @@ -219,6 +228,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool { return addr[0] == 0xff } +// IsV6UnicastAddress determines if the provided address is a valid IPv6 +// unicast (and specified) address. That is, IsV6UnicastAddress returns +// true if addr contains IPv6AddressSize bytes, is not the unspecified +// address and is not a multicast address. +func IsV6UnicastAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + + // Must not be unspecified + if addr == IPv6Any { + return false + } + + // Return if not a multicast. + return addr[0] != 0xff +} + // SolicitedNodeAddr computes the solicited-node multicast address. This is // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 // address. diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index c1f454805..74412c894 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -27,6 +27,11 @@ const ( udpChecksum = 6 ) +const ( + // UDPMaximumPacketSize is the largest possible UDP packet. + UDPMaximumPacketSize = 0xffff +) + // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 3fc14bacd..cc5f531e2 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "iptables", srcs = [ diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c40744b8e..18adb2085 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -44,14 +44,12 @@ type Endpoint struct { } // New creates a new channel endpoint. -func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ +func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { + return &Endpoint{ C: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } - - return stack.RegisterLinkEndpoint(e), e } // Drain removes all outbound packets from the channel and counts them. @@ -135,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } + +// Wait implements stack.LinkEndpoint.Wait. +func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index d786d8fdf..8fa9e3984 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -8,8 +9,8 @@ go_library( "endpoint.go", "endpoint_unsafe.go", "mmap.go", - "mmap_amd64.go", - "mmap_amd64_unsafe.go", + "mmap_stub.go", + "mmap_unsafe.go", "packet_dispatchers.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 77f988b9f..f80ac3435 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,6 +41,7 @@ package fdbased import ( "fmt" + "sync" "syscall" "golang.org/x/sys/unix" @@ -81,6 +82,19 @@ const ( PacketMMap ) +func (p PacketDispatchMode) String() string { + switch p { + case Readv: + return "Readv" + case RecvMMsg: + return "RecvMMsg" + case PacketMMap: + return "PacketMMap" + default: + return fmt.Sprintf("unknown packet dispatch mode %v", p) + } +} + type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. The endpoint will dispatch from all inbound channels as well as @@ -114,6 +128,9 @@ type endpoint struct { // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is // disabled. gsoMaxSize uint32 + + // wg keeps track of running goroutines. + wg sync.WaitGroup } // Options specify the details about the fd-based endpoint to be created. @@ -161,11 +178,20 @@ type Options struct { RXChecksumOffload bool } +// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT +// support in the host kernel. This allows us to use multiple FD's to receive +// from the same underlying NIC. The fanoutID needs to be the same for a given +// set of FD's that point to the same NIC. Trying to set the PACKET_FANOUT +// option for an FD with a fanoutID already in use by another FD for a different +// NIC will return an EINVAL. +var fanoutID = 1 + // New creates a new fd-based endpoint. // // Makes fd non-blocking, but does not take ownership of fd, which must remain -// open for the lifetime of the returned endpoint. -func New(opts *Options) (tcpip.LinkEndpointID, error) { +// open for the lifetime of the returned endpoint (until after the endpoint has +// stopped being using and Wait returns). +func New(opts *Options) (stack.LinkEndpoint, error) { caps := stack.LinkEndpointCapabilities(0) if opts.RXChecksumOffload { caps |= stack.CapabilityRXChecksumOffload @@ -190,7 +216,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } if len(opts.FDs) == 0 { - return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified") + return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } e := &endpoint{ @@ -207,12 +233,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { for i := 0; i < len(e.fds); i++ { fd := e.fds[i] if err := syscall.SetNonblock(fd, true); err != nil { - return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) + return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) } isSocket, err := isSocketFD(fd) if err != nil { - return 0, err + return nil, err } if isSocket { if opts.GSOMaxSize != 0 { @@ -222,12 +248,16 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket) if err != nil { - return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err) + return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err) } e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) } - return stack.RegisterLinkEndpoint(e), nil + // Increment fanoutID to ensure that we don't re-use the same fanoutID for + // the next endpoint. + fanoutID++ + + return e, nil } func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) { @@ -247,7 +277,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher case *unix.SockaddrLinklayer: // enable PACKET_FANOUT mode is the underlying socket is // of type AF_PACKET. - const fanoutID = 1 const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG fanoutArg := fanoutID | fanoutType<<16 if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { @@ -290,7 +319,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { // saved, they stop sending outgoing packets and all incoming packets // are rejected. for i := range e.inboundDispatchers { - go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above. + e.wg.Add(1) + go func(i int) { // S/R-SAFE: See above. + e.dispatchLoop(e.inboundDispatchers[i]) + e.wg.Done() + }(i) } } @@ -320,6 +353,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } +// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop +// reading from its FD. +func (e *endpoint) Wait() { + e.wg.Wait() +} + // virtioNetHdr is declared in linux/virtio_net.h. type virtioNetHdr struct { flags uint8 @@ -435,14 +474,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf } // NewInjectable creates a new fd-based InjectableEndpoint. -func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) { +func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint { syscall.SetNonblock(fd, true) - e := &InjectableEndpoint{endpoint: endpoint{ + return &InjectableEndpoint{endpoint: endpoint{ fds: []int{fd}, mtu: mtu, caps: capabilities, }} - - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e305252d6..04406bc9a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context { } opt.FDs = []int{fds[1]} - epID, err := New(opt) + ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } - ep := stack.FindLinkEndpoint(epID).(*endpoint) c := &context{ t: t, diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 2dca173c2..8bfeb97e4 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -12,12 +12,183 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !linux !amd64 +// +build linux,amd64 linux,arm64 package fdbased -// Stubbed out version for non-linux/non-amd64 platforms. +import ( + "encoding/binary" + "syscall" -func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { - return nil, nil + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" +) + +const ( + tPacketAlignment = uintptr(16) + tpStatusKernel = 0 + tpStatusUser = 1 + tpStatusCopy = 2 + tpStatusLosing = 4 +) + +// We overallocate the frame size to accommodate space for the +// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding. +// +// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB +// +// NOTE: +// Frames need to be aligned at 16 byte boundaries. +// BlockSize needs to be page aligned. +// +// For details see PACKET_MMAP setting constraints in +// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +const ( + tpFrameSize = 65536 + 128 + tpBlockSize = tpFrameSize * 32 + tpBlockNR = 1 + tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize +) + +// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct +// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>. +func tPacketAlign(v uintptr) uintptr { + return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1)) +} + +// tPacketReq is the tpacket_req structure as described in +// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +type tPacketReq struct { + tpBlockSize uint32 + tpBlockNR uint32 + tpFrameSize uint32 + tpFrameNR uint32 +} + +// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h> +type tPacketHdr []byte + +const ( + tpStatusOffset = 0 + tpLenOffset = 8 + tpSnapLenOffset = 12 + tpMacOffset = 16 + tpNetOffset = 18 + tpSecOffset = 20 + tpUSecOffset = 24 +) + +func (t tPacketHdr) tpLen() uint32 { + return binary.LittleEndian.Uint32(t[tpLenOffset:]) +} + +func (t tPacketHdr) tpSnapLen() uint32 { + return binary.LittleEndian.Uint32(t[tpSnapLenOffset:]) +} + +func (t tPacketHdr) tpMac() uint16 { + return binary.LittleEndian.Uint16(t[tpMacOffset:]) +} + +func (t tPacketHdr) tpNet() uint16 { + return binary.LittleEndian.Uint16(t[tpNetOffset:]) +} + +func (t tPacketHdr) tpSec() uint32 { + return binary.LittleEndian.Uint32(t[tpSecOffset:]) +} + +func (t tPacketHdr) tpUSec() uint32 { + return binary.LittleEndian.Uint32(t[tpUSecOffset:]) +} + +func (t tPacketHdr) Payload() []byte { + return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()] +} + +// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets. +// See: mmap_amd64_unsafe.go for implementation details. +type packetMMapDispatcher struct { + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // ringBuffer is only used when PacketMMap dispatcher is used and points + // to the start of the mmapped PACKET_RX_RING buffer. + ringBuffer []byte + + // ringOffset is the current offset into the ring buffer where the next + // inbound packet will be placed by the kernel. + ringOffset int +} + +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { + hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) + for hdr.tpStatus()&tpStatusUser == 0 { + event := rawfile.PollEvent{ + FD: int32(d.fd), + Events: unix.POLLIN | unix.POLLERR, + } + if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 { + if errno == syscall.EINTR { + continue + } + return nil, rawfile.TranslateErrno(errno) + } + if hdr.tpStatus()&tpStatusCopy != 0 { + // This frame is truncated so skip it after flipping the + // buffer to the kernel. + hdr.setTPStatus(tpStatusKernel) + d.ringOffset = (d.ringOffset + 1) % tpFrameNR + hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:]) + continue + } + } + + // Copy out the packet from the mmapped frame to a locally owned buffer. + pkt := make([]byte, hdr.tpSnapLen()) + copy(pkt, hdr.Payload()) + // Release packet to kernel. + hdr.setTPStatus(tpStatusKernel) + d.ringOffset = (d.ringOffset + 1) % tpFrameNR + return pkt, nil +} + +// dispatch reads packets from an mmaped ring buffer and dispatches them to the +// network stack. +func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { + pkt, err := d.readMMappedPacket() + if err != nil { + return false, err + } + var ( + p tcpip.NetworkProtocolNumber + remote, local tcpip.LinkAddress + ) + if d.e.hdrSize > 0 { + eth := header.Ethernet(pkt) + p = eth.Type() + remote = eth.SourceAddress() + local = eth.DestinationAddress() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + switch header.IPVersion(pkt) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + return true, nil + } + } + + pkt = pkt[d.e.hdrSize:] + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) + return true, nil } diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go deleted file mode 100644 index 029f86a18..000000000 --- a/pkg/tcpip/link/fdbased/mmap_amd64.go +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux,amd64 - -package fdbased - -import ( - "encoding/binary" - "syscall" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" -) - -const ( - tPacketAlignment = uintptr(16) - tpStatusKernel = 0 - tpStatusUser = 1 - tpStatusCopy = 2 - tpStatusLosing = 4 -) - -// We overallocate the frame size to accommodate space for the -// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding. -// -// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB -// -// NOTE: -// Frames need to be aligned at 16 byte boundaries. -// BlockSize needs to be page aligned. -// -// For details see PACKET_MMAP setting constraints in -// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt -const ( - tpFrameSize = 65536 + 128 - tpBlockSize = tpFrameSize * 32 - tpBlockNR = 1 - tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize -) - -// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct -// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>. -func tPacketAlign(v uintptr) uintptr { - return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1)) -} - -// tPacketReq is the tpacket_req structure as described in -// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt -type tPacketReq struct { - tpBlockSize uint32 - tpBlockNR uint32 - tpFrameSize uint32 - tpFrameNR uint32 -} - -// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h> -type tPacketHdr []byte - -const ( - tpStatusOffset = 0 - tpLenOffset = 8 - tpSnapLenOffset = 12 - tpMacOffset = 16 - tpNetOffset = 18 - tpSecOffset = 20 - tpUSecOffset = 24 -) - -func (t tPacketHdr) tpLen() uint32 { - return binary.LittleEndian.Uint32(t[tpLenOffset:]) -} - -func (t tPacketHdr) tpSnapLen() uint32 { - return binary.LittleEndian.Uint32(t[tpSnapLenOffset:]) -} - -func (t tPacketHdr) tpMac() uint16 { - return binary.LittleEndian.Uint16(t[tpMacOffset:]) -} - -func (t tPacketHdr) tpNet() uint16 { - return binary.LittleEndian.Uint16(t[tpNetOffset:]) -} - -func (t tPacketHdr) tpSec() uint32 { - return binary.LittleEndian.Uint32(t[tpSecOffset:]) -} - -func (t tPacketHdr) tpUSec() uint32 { - return binary.LittleEndian.Uint32(t[tpUSecOffset:]) -} - -func (t tPacketHdr) Payload() []byte { - return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()] -} - -// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets. -// See: mmap_amd64_unsafe.go for implementation details. -type packetMMapDispatcher struct { - // fd is the file descriptor used to send and receive packets. - fd int - - // e is the endpoint this dispatcher is attached to. - e *endpoint - - // ringBuffer is only used when PacketMMap dispatcher is used and points - // to the start of the mmapped PACKET_RX_RING buffer. - ringBuffer []byte - - // ringOffset is the current offset into the ring buffer where the next - // inbound packet will be placed by the kernel. - ringOffset int -} - -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { - hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) - for hdr.tpStatus()&tpStatusUser == 0 { - event := rawfile.PollEvent{ - FD: int32(d.fd), - Events: unix.POLLIN | unix.POLLERR, - } - if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 { - if errno == syscall.EINTR { - continue - } - return nil, rawfile.TranslateErrno(errno) - } - if hdr.tpStatus()&tpStatusCopy != 0 { - // This frame is truncated so skip it after flipping the - // buffer to the kernel. - hdr.setTPStatus(tpStatusKernel) - d.ringOffset = (d.ringOffset + 1) % tpFrameNR - hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:]) - continue - } - } - - // Copy out the packet from the mmapped frame to a locally owned buffer. - pkt := make([]byte, hdr.tpSnapLen()) - copy(pkt, hdr.Payload()) - // Release packet to kernel. - hdr.setTPStatus(tpStatusKernel) - d.ringOffset = (d.ringOffset + 1) % tpFrameNR - return pkt, nil -} - -// dispatch reads packets from an mmaped ring buffer and dispatches them to the -// network stack. -func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { - pkt, err := d.readMMappedPacket() - if err != nil { - return false, err - } - var ( - p tcpip.NetworkProtocolNumber - remote, local tcpip.LinkAddress - ) - if d.e.hdrSize > 0 { - eth := header.Ethernet(pkt) - p = eth.Type() - remote = eth.SourceAddress() - local = eth.DestinationAddress() - } else { - // We don't get any indication of what the packet is, so try to guess - // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(pkt) { - case header.IPv4Version: - p = header.IPv4ProtocolNumber - case header.IPv6Version: - p = header.IPv6ProtocolNumber - default: - return true, nil - } - } - - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) - return true, nil -} diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go new file mode 100644 index 000000000..67be52d67 --- /dev/null +++ b/pkg/tcpip/link/fdbased/mmap_stub.go @@ -0,0 +1,23 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !linux !amd64,!arm64 + +package fdbased + +// Stubbed out version for non-linux/non-amd64/non-arm64 platforms. + +func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + return nil, nil +} diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go index 47cb1d1cc..3894185ae 100644 --- a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go +++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build linux,amd64 +// +build linux,amd64 linux,arm64 package fdbased diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ab6a53988..b36629d2c 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -32,8 +32,8 @@ type endpoint struct { // New creates a new loopback endpoint. This link-layer endpoint just turns // outbound packets into inbound packets. -func New() tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{}) +func New() stack.LinkEndpoint { + return &endpoint{} } // Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- @@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa return nil } + +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index ea12ef1ac..1bab380b0 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a577a3d52..7c946101d 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -104,10 +104,16 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) * return endpoint.WriteRawPacket(dest, packet) } +// Wait implements stack.LinkEndpoint.Wait. +func (m *InjectableEndpoint) Wait() { + for _, ep := range m.routes { + ep.Wait() + } +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. -func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) { - e := &InjectableEndpoint{ +func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { + return &InjectableEndpoint{ routes: routes, } - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 174b9330f..3086fec00 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc if err != nil { t.Fatal("Failed to create socket pair:", err) } - _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) + underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - _, endpoint := NewInjectableEndpoint(routes) + endpoint := NewInjectableEndpoint(routes) return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP } diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 6e3a7a9d7..2e8bc772a 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -6,8 +6,9 @@ go_library( name = "rawfile", srcs = [ "blockingpoll_amd64.s", - "blockingpoll_amd64_unsafe.go", - "blockingpoll_unsafe.go", + "blockingpoll_arm64.s", + "blockingpoll_noyield_unsafe.go", + "blockingpoll_yield_unsafe.go", "errors.go", "rawfile_unsafe.go", ], diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s new file mode 100644 index 000000000..b62888b93 --- /dev/null +++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s @@ -0,0 +1,42 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// BlockingPoll makes the ppoll() syscall while calling the version of +// entersyscall that relinquishes the P so that other Gs can run. This is meant +// to be called in cases when the syscall is expected to block. +// +// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno) +TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 + BL ·callEntersyscallblock(SB) + MOVD fds+0(FP), R0 + MOVD nfds+8(FP), R1 + MOVD timeout+16(FP), R2 + MOVD $0x0, R3 // sigmask parameter which isn't used here + MOVD $0x49, R8 // SYS_PPOLL + SVC + CMP $0xfffffffffffff001, R0 + BLS ok + MOVD $-1, R1 + MOVD R1, n+24(FP) + NEG R0, R0 + MOVD R0, err+32(FP) + BL ·callExitsyscall(SB) + RET +ok: + MOVD R0, n+24(FP) + MOVD $0, err+32(FP) + BL ·callExitsyscall(SB) + RET diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go index 84dc0e918..621ab8d29 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build linux,!amd64 +// +build linux,!amd64,!arm64 package rawfile @@ -22,7 +22,7 @@ import ( ) // BlockingPoll is just a stub function that forwards to the ppoll() system call -// on non-amd64 platforms. +// on non-amd64 and non-arm64 platforms. func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) { n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index 47039a446..dda3b10a6 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build linux,amd64 +// +build linux,amd64 linux,arm64 // +build go1.12 // +build !go1.14 @@ -25,6 +25,12 @@ import ( _ "unsafe" // for go:linkname ) +// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the +// version of entersyscall that relinquishes the P so that other Gs can +// run. This is meant to be called in cases when the syscall is expected to +// block. On non amd64/arm64 platforms it just forwards to the ppoll() system +// call. +// //go:noescape func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index f2998aa98..0a5ea3dc4 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 94725cb11..330ed5e94 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 160a8f864..de1ce043d 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 834ea5c40..9e71d4edf 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -94,7 +94,7 @@ type endpoint struct { // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) { +func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { e := &endpoint{ mtu: mtu, bufferSize: bufferSize, @@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc } if err := e.tx.init(bufferSize, &tx); err != nil { - return 0, err + return nil, err } if err := e.rx.init(bufferSize, &rx); err != nil { e.tx.cleanup() - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } // Close frees all resources associated with the endpoint. @@ -132,7 +132,8 @@ func (e *endpoint) Close() { } } -// Wait waits until all workers have stopped after a Close() call. +// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have +// stopped after a Close() call. func (e *endpoint) Wait() { e.completed.Wait() } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 98036f367..0e9ba0846 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) if err != nil { t.Fatalf("New failed: %v", err) } - c.ep = stack.FindLinkEndpoint(id).(*endpoint) + c.ep = ep.(*endpoint) c.ep.Attach(c) return c diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 36c8c46fc..e401dce44 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -58,10 +58,10 @@ type endpoint struct { // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. -func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), - }) +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + return &endpoint{ + lower: lower, + } } func zoneOffset() (int32, error) { @@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { // snapLen is the maximum amount of a packet to be saved. Packets with a length // less than or equal too snapLen will be saved in their entirety. Longer // packets will be truncated to snapLen. -func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) { +func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) { if err := writePCAPHeader(file, snapLen); err != nil { - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), + return &endpoint{ + lower: lower, file: file, maxPCAPLen: snapLen, - }), nil + }, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is @@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return e.lower.WritePacket(r, gso, hdr, payload, protocol) } +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 2597d4b3e..0746dc8ec 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 3b6ac2ff7..5a1791cb5 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -40,11 +40,10 @@ type Endpoint struct { // New creates a new waitable link-layer endpoint. It wraps around another // endpoint and allows the caller to block new write/dispatch calls and wait for // the inflight ones to finish before returning. -func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ - lower: stack.FindLinkEndpoint(lower), +func New(lower stack.LinkEndpoint) *Endpoint { + return &Endpoint{ + lower: lower, } - return stack.RegisterLinkEndpoint(e), e } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. @@ -121,3 +120,6 @@ func (e *Endpoint) WaitWrite() { func (e *Endpoint) WaitDispatch() { e.dispatchGate.Close() } + +// Wait implements stack.LinkEndpoint.Wait. +func (e *Endpoint) Wait() {} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 56e18ecb0..ae23c96b7 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -70,9 +70,12 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P return nil } +// Wait implements stack.LinkEndpoint.Wait. +func (*countedEndpoint) Wait() {} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Write and check that it goes through. wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) @@ -97,7 +100,7 @@ func TestWaitWrite(t *testing.T) { func TestWaitDispatch(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Check that attach happens. wep.Attach(ep) @@ -139,7 +142,7 @@ func TestOtherMethods(t *testing.T) { hdrLen: hdrLen, linkAddr: linkAddr, } - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) if v := wep.MTU(); v != mtu { t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index f36f49453..9d16ff8c9 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d95d44f56..df0d3a8c0 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ea7296e6a..6b1e854dc 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -16,9 +16,9 @@ // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. // -// To use it in the networking stack, pass arp.ProtocolName as one of the -// network protocols when calling stack.New. Then add an "arp" address to -// every NIC on the stack that should respond to ARP requests. That is: +// To use it in the networking stack, pass arp.NewProtocol() as one of the +// network protocols when calling stack.New. Then add an "arp" address to every +// NIC on the stack that should respond to ARP requests. That is: // // if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { // // handle err @@ -33,9 +33,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ARP protocol name. - ProtocolName = "arp" - // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber @@ -82,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error { return tcpip.ErrNotSupported } @@ -204,8 +201,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an ARP network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 4c4b54469..88b57ec03 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -44,14 +44,19 @@ type testContext struct { } func newTestContext(t *testing.T) *testContext { - s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + }) const defaultMTU = 65536 - id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(256, defaultMTU, stackLinkAddr) + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -73,7 +78,7 @@ func newTestContext(t *testing.T) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 118bfc763..2cad0a0b6 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "reassembler_list", @@ -27,6 +28,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/tcpip", "//pkg/tcpip/buffer", ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 1628a82be..6da5238ec 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -17,6 +17,7 @@ package fragmentation import ( + "fmt" "log" "sync" "time" @@ -82,7 +83,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t // Process processes an incoming fragment belonging to an ID // and returns a complete packet when all the packets belonging to that ID have been received. -func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) { +func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { f.mu.Lock() r, ok := f.reassemblers[id] if ok && r.tooOld(f.timeout) { @@ -97,8 +98,15 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } f.mu.Unlock() - res, done, consumed := r.process(first, last, more, vv) - + res, done, consumed, err := r.process(first, last, more, vv) + if err != nil { + // We probably got an invalid sequence of fragments. Just + // discard the reassembler and move on. + f.mu.Lock() + f.release(r) + f.mu.Unlock() + return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err) + } f.mu.Lock() f.size += consumed if done { @@ -114,7 +122,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } } f.mu.Unlock() - return res, done + return res, done, nil } func (f *Fragmentation) release(r *reassembler) { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 799798544..72c0f53be 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -83,7 +83,10 @@ func TestFragmentationProcess(t *testing.T) { t.Run(c.comment, func(t *testing.T) { f := NewFragmentation(1024, 512, DefaultReassembleTimeout) for i, in := range c.in { - vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) + vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) + if err != nil { + t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) + } if !reflect.DeepEqual(vv, c.out[i].vv) { t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) } @@ -114,7 +117,10 @@ func TestReassemblingTimeout(t *testing.T) { time.Sleep(2 * timeout) // Send another fragment that completes a packet. // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done := f.Process(0, 1, 1, false, vv(1, "1")) + _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) + if err != nil { + t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) + } if done { t.Errorf("Fragmentation does not respect the reassembling timeout.") } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 8037f734b..9e002e396 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) { +func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -86,7 +86,7 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, false, consumed + return buffer.VectorisedView{}, false, consumed, nil } if r.updateHoles(first, last, more) { // We store the incoming packet only if it filled some holes. @@ -96,13 +96,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise } // Check if all the holes have been deleted and we are ready to reassamble. if r.deleted < len(r.holes) { - return buffer.VectorisedView{}, false, consumed + return buffer.VectorisedView{}, false, consumed, nil } res, err := r.heap.reassemble() if err != nil { - panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err)) + return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err) } - return res, true, consumed + return res, true, consumed, nil } func (r *reassembler) tooOld(timeout time.Duration) bool { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6bbfcd97f..f644a8b08 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*testObject) Wait() {} + // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. @@ -169,7 +172,10 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen } func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ @@ -182,7 +188,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ @@ -221,7 +230,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -319,7 +328,8 @@ func TestIPv4ReceiveControl(t *testing.T) { icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) icmp.SetType(header.ICMPv4DstUnreachable) icmp.SetCode(c.code) - copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef}) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) // Create the inner IPv4 header. ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) @@ -450,7 +460,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -539,7 +549,7 @@ func TestIPv6ReceiveControl(t *testing.T) { defer ep.Close() - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4 + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize } @@ -559,10 +569,11 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) icmp.SetType(c.typ) icmp.SetCode(c.code) - copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) ip.Encode(&header.IPv6Fields{ PayloadLength: 100, NextHeader: 10, @@ -574,7 +585,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Build the fragmentation header if needed. if c.fragmentOffset != nil { ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) frag.Encode(&header.IPv6FragmentFields{ NextHeader: 10, FragmentOffset: *c.fragmentOffset, diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index be84fa63d..58e537aad 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 497164cbb..50b363dc4 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,8 +15,6 @@ package ipv4 import ( - "encoding/binary" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -97,7 +95,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -117,7 +115,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V e.handleControl(stack.ControlPortUnreachable, 0, vv) case header.ICMPv4FragmentationNeeded: - mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:])) + mtu := uint32(h.MTU()) e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b7a06f525..5cd895ff0 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -14,9 +14,9 @@ // Package ipv4 contains the implementation of the ipv4 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv4.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv4.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv4.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv4 @@ -32,9 +32,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv4 protocol name. - ProtocolName = "ipv4" - // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -42,6 +39,9 @@ const ( // TotalLength field of the ipv4 header. MaxTotalSize = 0xffff + // DefaultTTL is the default time-to-live value for this endpoint. + DefaultTTL = 64 + // buckets is the number of identifier buckets. buckets = 2048 ) @@ -53,6 +53,7 @@ type endpoint struct { linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation + protocol *protocol } // NewEndpoint creates a new ipv4 endpoint. @@ -64,6 +65,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi linkEP: linkEP, dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + protocol: p, } return e, nil @@ -71,7 +73,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi // DefaultTTL is the default time-to-live value for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -197,21 +199,22 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payload.Size()) id := uint32(0) if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) } ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, ID: uint16(id), - TTL: ttl, - Protocol: uint8(protocol), + TTL: params.TTL, + TOS: params.TOS, + Protocol: uint8(params.Protocol), SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) @@ -267,7 +270,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect if payload.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) } ip.SetID(uint16(id)) } @@ -294,6 +297,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { headerView := vv.First() h := header.IPv4(headerView) if !h.IsValid(vv.Size()) { + r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -304,10 +308,31 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { + if vv.Size() == 0 { + // Drop the packet as it's marked as a fragment but has + // no payload. + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } // The packet is a fragment, let's try to reassemble it. last := h.FragmentOffset() + uint16(vv.Size()) - 1 + // Drop the packet if the fragmentOffset is incorrect. i.e the + // combination of fragmentOffset and vv.size() causes a wrap + // around resulting in last being less than the offset. + if last < h.FragmentOffset() { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } var ready bool - vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + var err error + vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } if !ready { return } @@ -325,14 +350,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() {} -type protocol struct{} +type protocol struct { + ids []uint32 + hashIV uint32 -// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 } // Number returns the ipv4 protocol number. @@ -358,12 +383,34 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -378,7 +425,7 @@ func calculateMTU(mtu uint32) uint32 { // hashRoute calculates a hash value for the given route. It uses the source & // destination address, the transport protocol number, and a random initial // value (generated once on initialization) to generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { +func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { t := r.LocalAddress a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 t = r.RemoteAddress @@ -386,22 +433,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -var ( - ids []uint32 - hashIV uint32 -) - -func init() { - ids = make([]uint32, buckets) +// NewProtocol returns an IPv4 network protocol. +func NewProtocol() stack.NetworkProtocol { + ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. r := hash.RandN32(1 + buckets) for i := range ids { ids[i] = r[i] } - hashIV = r[buckets] + hashIV := r[buckets] - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) + return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 1b5a55bea..85ab0e3bc 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -33,20 +33,20 @@ import ( ) func TestExcludeBroadcast(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) const defaultMTU = 65536 - id, _ := channel.New(256, defaultMTU, "") + ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { - id = sniffer.New(id) + ep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { t.Fatalf("AddAddress failed: %v", err) } @@ -184,15 +184,12 @@ type errorChannel struct { // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket // will return successive errors from packetCollectorErrors until the list is // empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { - _, e := channel.New(size, mtu, linkAddr) - ec := errorChannel{ - Endpoint: e, +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { + return &errorChannel{ + Endpoint: channel.New(size, mtu, linkAddr), Ch: make(chan packetInfo, size), packetCollectorErrors: packetCollectorErrors, } - - return stack.RegisterLinkEndpoint(e), &ec } // packetInfo holds all the information about an outbound packet. @@ -241,10 +238,11 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. - s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) - _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - linkEPId := stack.RegisterLinkEndpoint(linkEP) - s.CreateNIC(1, linkEPId) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + s.CreateNIC(1, ep) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -266,7 +264,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 } return context{ Route: r, - linkEP: linkEP, + linkEP: ep, } } @@ -304,7 +302,7 @@ func TestFragmentation(t *testing.T) { Payload: payload.Clone([]buffer.View{}), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) if err != nil { t.Errorf("err got %v, want %v", err, nil) } @@ -351,7 +349,7 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -368,3 +366,117 @@ func TestFragmentationErrors(t *testing.T) { }) } } + +func TestInvalidFragments(t *testing.T) { + // These packets have both IHL and TotalLength set to 0. + testCases := []struct { + name string + packets [][]byte + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + }{ + { + "ihl_totallen_zero_valid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + "ihl_totallen_zero_invalid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + // Total Length of 37(20 bytes IP header + 17 bytes of + // payload) + // Frag Offset of 0x1ffe = 8190*8 = 65520 + // Leading to the fragment end to be past 65535. + "ihl_totallen_valid_invalid_frag_offset_1", + [][]byte{ + {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + // The following 3 tests were found by running a fuzzer and were + // triggering a panic in the IPv4 reassembler code. + { + "ihl_less_than_ipv4_minimum_size_1", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_2", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_3", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "fragment_with_short_total_len_extra_payload", + [][]byte{ + {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + { + "multiple_fragments_with_more_fragments_set_to_false", + [][]byte{ + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + 1, + 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + const nicid tcpip.NICID = 42 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + }, + }) + + var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) + var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) + ep := channel.New(10, 1500, linkAddr) + s.CreateNIC(nicid, sniffer.New(ep)) + + for _, pkt := range tc.packets { + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt})) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { + t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { + t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index fae7f4507..f06622a8b 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -23,7 +24,11 @@ go_library( go_test( name = "ipv6_test", size = "small", - srcs = ["icmp_test.go"], + srcs = [ + "icmp_test.go", + "ipv6_test.go", + "ndp_test.go", + ], embed = [":ipv6"], deps = [ "//pkg/tcpip", @@ -33,6 +38,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 5e6a59e91..f543ceb92 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,14 +15,21 @@ package ipv6 import ( - "encoding/binary" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +const ( + // ndpHopLimit is the expected IP hop limit value of 255 for received + // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, + // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently + // drop the NDP packet. All outgoing NDP packets must use this value for + // its IP hop limit field. + ndpHopLimit = 255 +) + // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP @@ -73,6 +80,21 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V } h := header.ICMPv6(v) + // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and + // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field + // in the IPv6 header is not set to 255. + switch h.Type() { + case header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborAdvert, + header.ICMPv6RouterSolicit, + header.ICMPv6RouterAdvert, + header.ICMPv6RedirectMsg: + if header.IPv6(netHeader).HopLimit() != ndpHopLimit { + received.Invalid.Increment() + return + } + } + // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv6PacketTooBig: @@ -82,7 +104,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V return } vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:]) + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) case header.ICMPv6DstUnreachable: @@ -99,19 +121,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:16]) + targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) pkt.SetType(header.ICMPv6NeighborAdvert) @@ -132,9 +150,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + // TODO(tamird/ghanan): there exists an explicit NDP option that is + // used to update the neighbor table with link addresses for a + // neighbor from an NS (see the Source Link Layer option RFC + // 4861 section 4.6.1 and section 7.2.3). + // + // Furthermore, the entirety of NDP handling here seems to be + // contradicted by RFC 4861. + e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) + // + // 7.1.2. Validation of Neighbor Advertisements + // + // The IP Hop Limit field has a value of 255, i.e., the packet + // could not possibly have been forwarded by a router. + if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ndpHopLimit, TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -146,7 +179,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:16]) + targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) @@ -158,14 +191,13 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) + if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -235,14 +267,14 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr pkt[icmpV6LengthOffset] = 1 copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) length := uint16(hdr.UsedLength()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: defaultIPv6HopLimit, + HopLimit: ndpHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) @@ -274,24 +306,3 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } return "", false } - -func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := header.Checksum([]byte(src), 0) - xsum = header.Checksum([]byte(dst), xsum) - var upperLayerLength [4]byte - binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) - xsum = header.Checksum(upperLayerLength[:], xsum) - xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum) - for _, v := range vv.Views() { - xsum = header.Checksum(v, xsum) - } - - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^header.Checksum(h, xsum) - h[2], h[3] = h2, h3 - - return xsum -} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d0dc72506..dd3c4d7c4 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,7 +15,6 @@ package ipv6 import ( - "fmt" "reflect" "strings" "testing" @@ -81,10 +80,12 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li } func TestICMPCounts(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -142,7 +143,7 @@ func TestICMPCounts(t *testing.T) { ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: r.DefaultTTL(), + HopLimit: ndpHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) @@ -153,7 +154,7 @@ func TestICMPCounts(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) pkt := header.ICMPv6(hdr.Prepend(typ.size)) pkt.SetType(typ.typ) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) handleIPv6Payload(hdr) } @@ -177,13 +178,10 @@ func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { t := v.Type() for i := 0; i < v.NumField(); i++ { v := v.Field(i) - switch v.Kind() { - case reflect.Ptr: - f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter)) - case reflect.Struct: + if s, ok := v.Interface().(*tcpip.StatCounter); ok { + f(t.Field(i).Name, s) + } else { visitStats(v, f) - default: - panic(fmt.Sprintf("unexpected type %s", v.Type())) } } } @@ -206,41 +204,38 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ - s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), + s0: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), + s1: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), } const defaultMTU = 65536 - _, linkEP0 := channel.New(256, defaultMTU, linkAddr0) - c.linkEP0 = linkEP0 - wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0} - id0 := stack.RegisterLinkEndpoint(wrappedEP0) + c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + + wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { - id0 = sniffer.New(id0) + wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, id0); err != nil { + if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { - t.Fatalf("AddAddress sn lladdr0: %v", err) - } - _, linkEP1 := channel.New(256, defaultMTU, linkAddr1) - c.linkEP1 = linkEP1 - wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1} - id1 := stack.RegisterLinkEndpoint(wrappedEP1) - if err := c.s1.CreateNIC(1, id1); err != nil { + c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) + if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { - t.Fatalf("AddAddress sn lladdr1: %v", err) - } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) if err != nil { @@ -321,7 +316,7 @@ func TestLinkResolution(t *testing.T) { hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) payload := tcpip.SlicePayload(hdr.View()) // We can't send our payload directly over the route because that diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331a8bdaa..cd1e34085 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -14,13 +14,15 @@ // Package ipv6 contains the implementation of the ipv6 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv6.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv6.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv6.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv6 import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -28,9 +30,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv6 protocol name. - ProtocolName = "ipv6" - // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber @@ -38,9 +37,9 @@ const ( // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff - // defaultIPv6HopLimit is the default hop limit for IPv6 Packets - // egressed by Netstack. - defaultIPv6HopLimit = 255 + // DefaultTTL is the default hop limit for IPv6 Packets egressed by + // Netstack. + DefaultTTL = 64 ) type endpoint struct { @@ -50,11 +49,12 @@ type endpoint struct { linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher + protocol *protocol } // DefaultTTL is the default hop limit for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -98,13 +98,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { length := uint16(hdr.UsedLength() + payload.Size()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, - NextHeader: uint8(protocol), - HopLimit: ttl, + NextHeader: uint8(params.Protocol), + HopLimit: params.TTL, + TrafficClass: params.TOS, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) @@ -158,14 +159,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { // Close cleans up resources associated with the endpoint. func (*endpoint) Close() {} -type protocol struct{} - -// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} +type protocol struct { + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 } // Number returns the ipv6 protocol number. @@ -198,17 +196,40 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, + protocol: p, }, nil } // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -221,8 +242,7 @@ func calculateMTU(mtu uint32) uint32 { return maxPayloadSize } -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an IPv6 network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go new file mode 100644 index 000000000..deaa9b7f3 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -0,0 +1,266 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // The least significant 3 bytes are the same as addr2 so both addr2 and + // addr3 will have the same solicited-node address. + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" +) + +// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the +// expected Neighbor Advertisement received count after receiving the packet. +func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + // Receive ICMP packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stats := s.Stats().ICMP.V6PacketsReceived + + if got := stats.NeighborAdvert.Value(); got != want { + t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) + } +} + +// testReceiveUDP tests receiving a UDP packet from src to dst. want is the +// expected UDP received count after receiving the packet. +func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + + // Receive UDP Packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + + // UDP pseudo-header checksum. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) + + // UDP checksum + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stat := s.Stats().UDP.PacketsReceived + + if got := stat.Value(); got != want { + t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) + } +} + +// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and +// UDP packets destined to the IPv6 link-local all-nodes multicast address. +func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should receive a packet destined to the all-nodes + // multicast address. + test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) + }) + } +} + +// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP +// packets destined to the IPv6 solicited-node address of an assigned IPv6 +// address. +func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, + } + + snmc := header.SolicitedNodeAddr(addr2) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as we haven't added + // those addresses. + test.rxf(t, s, e, addr1, snmc, 0) + + if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + } + + // Should receive a packet destined to the solicited + // node address of addr2/addr3 now that we have added + // added addr2. + test.rxf(t, s, e, addr1, snmc, 1) + + if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have added addr3. + test.rxf(t, s, e, addr1, snmc, 2) + + if err := s.RemoveAddress(1, addr2); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have removed addr2. + test.rxf(t, s, e, addr1, snmc, 3) + + if err := s.RemoveAddress(1, addr3); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as both of them got + // removed. + test.rxf(t, s, e, addr1, snmc, 3) + }) + } +} + +// TestAddIpv6Address tests adding IPv6 addresses. +func TestAddIpv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + }{ + // This test is in response to b/140943433. + { + "Nil", + tcpip.Address([]byte(nil)), + }, + { + "ValidUnicast", + addr1, + }, + { + "ValidLinkLocalUnicast", + lladdr0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go new file mode 100644 index 000000000..e30791fe3 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -0,0 +1,181 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" +) + +// setupStackAndEndpoint creates a stack with a single NIC with a link-local +// address llladdr and an IPv6 endpoint to a remote with link-local address +// rlladdr +func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { + t.Helper() + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) + + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) + } + + { + subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + + ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil) + if err != nil { + t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + } + + return s, ep +} + +// TestHopLimitValidation is a test that makes sure that NDP packets are only +// received if their IP header's hop limit is set to 255. +func TestHopLimitValidation(t *testing.T) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + t.Helper() + + // Create a stack with the assigned link-local address lladdr0 + // and an endpoint to lladdr1. + s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) + + r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + } + + return s, ep, r + } + + handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) { + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: hopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + ep.HandlePacket(r, hdr.View().ToVectorisedView()) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }}, + {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }}, + {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }}, + {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }}, + {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }}, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + s, ep, r := setup(t) + defer r.Release() + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) + pkt := header.ICMPv6(hdr.Prepend(typ.size)) + pkt.SetType(typ.typ) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + // Should not have received any ICMPv6 packets with + // type = typ.typ. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Receive the NDP packet with an invalid hop limit + // value. + handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r) + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + + // Rx count of NDP packet of type typ.typ should not + // have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Receive the NDP packet with a valid hop limit value. + handleIPv6Payload(hdr, ndpHopLimit, ep, &r) + + // Rx count of NDP packet of type typ.typ should have + // increased. + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 989058413..11efb4e44 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 315780c0c..30cea8996 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -19,6 +19,7 @@ import ( "math" "math/rand" "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -27,6 +28,10 @@ const ( // FirstEphemeral is the first ephemeral port. FirstEphemeral = 16000 + // numEphemeralPorts it the mnumber of available ephemeral ports to + // Netstack. + numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1 + anyIPAddress tcpip.Address = "" ) @@ -40,6 +45,13 @@ type portDescriptor struct { type PortManager struct { mu sync.RWMutex allocatedPorts map[portDescriptor]bindAddresses + + // hint is used to pick ports ephemeral ports in a stable order for + // a given port offset. + // + // hint must be accessed using the portHint/incPortHint helpers. + // TODO(gvisor.dev/issue/940): S/R this field. + hint uint32 } type portNode struct { @@ -47,43 +59,76 @@ type portNode struct { refs int } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]portNode +// deviceNode is never empty. When it has no elements, it is removed from the +// map that references it. +type deviceNode map[tcpip.NICID]portNode -// isAvailable checks whether an IP address is available to bind to. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool { - if addr == anyIPAddress { - if len(b) == 0 { - return true - } +// isAvailable checks whether binding is possible by device. If not binding to a +// device, check against all portNodes. If binding to a specific device, check +// against the unspecified device and the provided device. +func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { + if bindToDevice == 0 { + // Trying to binding all devices. if !reuse { + // Can't bind because the (addr,port) is already bound. return false } - for _, n := range b { - if !n.reuse { + for _, p := range d { + if !p.reuse { + // Can't bind because the (addr,port) was previously bound without reuse. return false } } return true } - // If all addresses for this portDescriptor are already bound, no - // address is available. - if n, ok := b[anyIPAddress]; ok { - if !reuse { + if p, ok := d[0]; ok { + if !reuse || !p.reuse { return false } - if !n.reuse { + } + + if p, ok := d[bindToDevice]; ok { + if !reuse || !p.reuse { return false } } - if n, ok := b[addr]; ok { - if !reuse { + return true +} + +// bindAddresses is a set of IP addresses. +type bindAddresses map[tcpip.Address]deviceNode + +// isAvailable checks whether an IP address is available to bind to. If the +// address is the "any" address, check all other addresses. Otherwise, just +// check against the "any" address and the provided address. +func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { + if addr == anyIPAddress { + // If binding to the "any" address then check that there are no conflicts + // with all addresses. + for _, d := range b { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + return true + } + + // Check that there is no conflict with the "any" address. + if d, ok := b[anyIPAddress]; ok { + if !d.isAvailable(reuse, bindToDevice) { return false } - return n.reuse } + + // Check that this is no conflict with the provided address. + if d, ok := b[addr]; ok { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + return true } @@ -97,11 +142,40 @@ func NewPortManager() *PortManager { // is suitable for its needs, and stopping when a port is found or an error // occurs. func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { - count := uint16(math.MaxUint16 - FirstEphemeral + 1) - offset := uint16(rand.Int31n(int32(count))) + offset := uint32(rand.Int31n(numEphemeralPorts)) + return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) +} + +// portHint atomically reads and returns the s.hint value. +func (s *PortManager) portHint() uint32 { + return atomic.LoadUint32(&s.hint) +} + +// incPortHint atomically increments s.hint by 1. +func (s *PortManager) incPortHint() { + atomic.AddUint32(&s.hint, 1) +} - for i := uint16(0); i < count; i++ { - port = FirstEphemeral + (offset+i)%count +// PickEphemeralPortStable starts at the specified offset + s.portHint and +// iterates over all ephemeral ports, allowing the caller to decide whether a +// given port is suitable for its needs and stopping when a port is found or an +// error occurs. +func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) + if err == nil { + s.incPortHint() + } + return p, err + +} + +// pickEphemeralPort starts at the offset specified from the FirstEphemeral port +// and iterates over the number of ports specified by count and allows the +// caller to decide whether a given port is suitable for its needs, and stopping +// when a port is found or an error occurs. +func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + for i := uint32(0); i < count; i++ { + port = uint16(FirstEphemeral + (offset+i)%count) ok, err := testPort(port) if err != nil { return 0, err @@ -116,17 +190,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse) + return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse) { + if !addrs.isAvailable(addr, reuse, bindToDevice) { return false } } @@ -138,14 +212,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse) { + if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -153,13 +227,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil + return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { return false } @@ -171,11 +245,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m = make(bindAddresses) s.allocatedPorts[desc] = m } - if n, ok := m[addr]; ok { + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + if n, ok := d[bindToDevice]; ok { n.refs++ - m[addr] = n + d[bindToDevice] = n } else { - m[addr] = portNode{reuse: reuse, refs: 1} + d[bindToDevice] = portNode{reuse: reuse, refs: 1} } } @@ -184,22 +263,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { - n, ok := m[addr] + d, ok := m[addr] + if !ok { + continue + } + n, ok := d[bindToDevice] if !ok { continue } n.refs-- + d[bindToDevice] = n if n.refs == 0 { + delete(d, bindToDevice) + } + if len(d) == 0 { delete(m, addr) - } else { - m[addr] = n } if len(m) == 0 { delete(s.allocatedPorts, desc) diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 689401661..19f4833fc 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -15,6 +15,7 @@ package ports import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -34,6 +35,7 @@ type portReserveTestAction struct { want *tcpip.Error reuse bool release bool + device tcpip.NICID } func TestPortReservation(t *testing.T) { @@ -100,6 +102,112 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: anyIPAddress, release: true}, {port: 24, ip: anyIPAddress, reuse: false, want: nil}, }, + }, { + tname: "bind twice with device fails", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 3, want: nil}, + {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 1, want: nil}, + {port: 24, ip: fakeIPAddress, device: 2, want: nil}, + }, + }, { + tname: "bind to device and then without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + }, + }, { + tname: "binding with reuse and device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "mixing reuse and not reuse by binding to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: nil}, + }, + }, { + tname: "can't bind to 0 after mixing reuse and not reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind and release", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + + // Release the bind to device 0 and try again. + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + }, + }, { + tname: "bind twice with reuse once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "release an unreserved device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + // The below don't exist. + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + // Release all. + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -108,12 +216,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) @@ -125,7 +233,6 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { - pm := NewPortManager() customErr := &tcpip.Error{} for _, test := range []struct { name string @@ -169,9 +276,63 @@ func TestPickEphemeralPort(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { + pm := NewPortManager() if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) } }) } } + +func TestPickEphemeralPortStable(t *testing.T) { + customErr := &tcpip.Error{} + for _, test := range []struct { + name string + f func(port uint16) (bool, *tcpip.Error) + wantErr *tcpip.Error + wantPort uint16 + }{ + { + name: "no-port-available", + f: func(port uint16) (bool, *tcpip.Error) { + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + { + name: "port-tester-error", + f: func(port uint16) (bool, *tcpip.Error) { + return false, customErr + }, + wantErr: customErr, + }, + { + name: "only-port-16042-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port == FirstEphemeral+42 { + return true, nil + } + return false, nil + }, + wantPort: FirstEphemeral + 42, + }, + { + name: "only-port-under-16000-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port < FirstEphemeral { + return true, nil + } + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + } { + t.Run(test.name, func(t *testing.T) { + pm := NewPortManager() + portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) + if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { + t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + } + }) + } +} diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index e2021cd15..2239c1e66 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -126,7 +126,10 @@ func main() { // Create the stack with ipv4 and tcp protocols, then add a tun-based // NIC and ipv4 address. - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) mtu, err := rawfile.GetMTU(tunName) if err != nil { @@ -138,11 +141,11 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) + linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) if err != nil { log.Fatal(err) } - if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil { + if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 1716be285..bca73cbb1 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -111,7 +111,10 @@ func main() { // Create the stack with ip and tcp protocols, then add a tun-based // NIC and address. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) mtu, err := rawfile.GetMTU(tunName) if err != nil { @@ -128,7 +131,7 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{ + linkEP, err := fdbased.New(&fdbased.Options{ FDs: []int{fd}, MTU: mtu, EthernetHeader: *tap, @@ -137,7 +140,7 @@ func main() { if err != nil { log.Fatal(err) } - if err := s.CreateNIC(1, linkID); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index 76b5f4ffa..29b7d761c 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "seqnum", srcs = ["seqnum.go"], diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 9986b4be3..6a78432c9 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,11 +1,27 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +go_template_instance( + name = "linkaddrentry_list", + out = "linkaddrentry_list.go", + package = "stack", + prefix = "linkAddrEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*linkAddrEntry", + "Linker": "*linkAddrEntry", + }, +) go_library( name = "stack", srcs = [ + "icmp_rate_limit.go", "linkaddrcache.go", + "linkaddrentry_list.go", "nic.go", "registration.go", "route.go", @@ -19,6 +35,7 @@ go_library( ], deps = [ "//pkg/ilist", + "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", @@ -28,6 +45,7 @@ go_library( "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", + "@org_golang_x_time//rate:go_default_library", ], ) @@ -36,6 +54,7 @@ go_test( size = "small", srcs = [ "stack_test.go", + "transport_demuxer_test.go", "transport_test.go", ], deps = [ @@ -46,6 +65,9 @@ go_test( "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) @@ -60,3 +82,11 @@ go_test( "//pkg/tcpip", ], ) + +filegroup( + name = "autogen", + srcs = [ + "linkaddrentry_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go new file mode 100644 index 000000000..3a20839da --- /dev/null +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -0,0 +1,41 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "golang.org/x/time/rate" +) + +const ( + // icmpLimit is the default maximum number of ICMP messages permitted by this + // rate limiter. + icmpLimit = 1000 + + // icmpBurst is the default number of ICMP messages that can be sent in a single + // burst. + icmpBurst = 50 +) + +// ICMPRateLimiter is a global rate limiter that controls the generation of +// ICMP messages generated by the stack. +type ICMPRateLimiter struct { + *rate.Limiter +} + +// NewICMPRateLimiter returns a global rate limiter for controlling the rate +// at which ICMP messages are generated by the stack. +func NewICMPRateLimiter() *ICMPRateLimiter { + return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} +} diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 77bb0ccb9..267df60d1 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -42,10 +42,11 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - mu sync.Mutex - cache map[tcpip.FullAddress]*linkAddrEntry - next int // array index of next available entry - entries [linkAddrCacheSize]linkAddrEntry + cache struct { + sync.Mutex + table map[tcpip.FullAddress]*linkAddrEntry + lru linkAddrEntryList + } } // entryState controls the state of a single entry in the cache. @@ -60,9 +61,6 @@ const ( // failed means that address resolution timed out and the address // could not be resolved. failed - // expired means that the cache entry has expired and the address must be - // resolved again. - expired ) // String implements Stringer. @@ -74,8 +72,6 @@ func (s entryState) String() string { return "ready" case failed: return "failed" - case expired: - return "expired" default: return fmt.Sprintf("unknown(%d)", s) } @@ -84,64 +80,46 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + linkAddrEntryEntry + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of 'incomplete' these waiters are notified. + // state transitions out of incomplete these waiters are notified. wakers map[*sleep.Waker]struct{} + // done is used to allow callers to wait on address resolution. It is nil iff + // s is incomplete and resolution is not yet in progress. done chan struct{} } -func (e *linkAddrEntry) state() entryState { - if e.s != expired && time.Now().After(e.expiration) { - // Force the transition to ensure waiters are notified. - e.changeState(expired) - } - return e.s -} - -func (e *linkAddrEntry) changeState(ns entryState) { - if e.s == ns { - return - } - - // Validate state transition. - switch e.s { - case incomplete: - // All transitions are valid. - case ready, failed: - if ns != expired { - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - } - case expired: - // Terminal state. - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - default: - panic(fmt.Sprintf("invalid state: %s", e.s)) - } - +// changeState sets the entry's state to ns, notifying any waiters. +// +// The entry's expiration is bumped up to the greater of itself and the passed +// expiration; the zero value indicates immediate expiration, and is set +// unconditionally - this is an implementation detail that allows for entries +// to be reused. +func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { // Notify whoever is waiting on address resolution when transitioning - // out of 'incomplete'. - if e.s == incomplete { + // out of incomplete. + if e.s == incomplete && ns != incomplete { for w := range e.wakers { w.Assert() } e.wakers = nil - if e.done != nil { - close(e.done) + if ch := e.done; ch != nil { + close(ch) } + e.done = nil } - e.s = ns -} -func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) { - if w != nil { - e.wakers[w] = struct{}{} + if expiration.IsZero() || expiration.After(e.expiration) { + e.expiration = expiration } + e.s = ns } func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { @@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] - if ok { - s := entry.state() - if s != expired && entry.linkAddr == v { - // Disregard repeated calls. - return - } - // Check if entry is waiting for address resolution. - if s == incomplete { - entry.linkAddr = v - } else { - // Otherwise create a new entry to replace it. - entry = c.makeAndAddEntry(k, v) - } - } else { - entry = c.makeAndAddEntry(k, v) - } + // Calculate expiration time before acquiring the lock, since expiration is + // relative to the time when information was learned, rather than when it + // happened to be inserted into the cache. + expiration := time.Now().Add(c.ageLimit) - entry.changeState(ready) + c.cache.Lock() + entry := c.getOrCreateEntryLocked(k) + entry.linkAddr = v + + entry.changeState(ready, expiration) + c.cache.Unlock() } -// makeAndAddEntry is a helper function to create and add a new -// entry to the cache map and evict older entry as needed. -func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry { - // Take over the next entry. - entry := &c.entries[c.next] - if c.cache[entry.addr] == entry { - delete(c.cache, entry.addr) +// getOrCreateEntryLocked retrieves a cache entry associated with k. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { + if entry, ok := c.cache.table[k]; ok { + c.cache.lru.Remove(entry) + c.cache.lru.PushFront(entry) + return entry } + var entry *linkAddrEntry + if len(c.cache.table) == linkAddrCacheSize { + entry = c.cache.lru.Back() - // Mark the soon-to-be-replaced entry as expired, just in case there is - // someone waiting for address resolution on it. - entry.changeState(expired) + delete(c.cache.table, entry.addr) + c.cache.lru.Remove(entry) - *entry = linkAddrEntry{ - addr: k, - linkAddr: v, - expiration: time.Now().Add(c.ageLimit), - wakers: make(map[*sleep.Waker]struct{}), - done: make(chan struct{}), + // Wake waiters and mark the soon-to-be-reused entry as expired. Note + // that the state passed doesn't matter when the zero time is passed. + entry.changeState(failed, time.Time{}) + } else { + entry = new(linkAddrEntry) } - c.cache[k] = entry - c.next = (c.next + 1) % len(c.entries) + *entry = linkAddrEntry{ + addr: k, + s: incomplete, + } + c.cache.table[k] = entry + c.cache.lru.PushFront(entry) return entry } @@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } } - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.cache[k]; ok { - switch s := entry.state(); s { - case expired: - case ready: - return entry.linkAddr, nil, nil - case failed: - return "", nil, tcpip.ErrNoLinkAddress - case incomplete: - // Address resolution is still in progress. - entry.maybeAddWaker(waker) - return "", entry.done, tcpip.ErrWouldBlock - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + c.cache.Lock() + defer c.cache.Unlock() + entry := c.getOrCreateEntryLocked(k) + switch s := entry.s; s { + case ready, failed: + if !time.Now().After(entry.expiration) { + // Not expired. + switch s { + case ready: + return entry.linkAddr, nil, nil + case failed: + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } - } - if linkRes == nil { - return "", nil, tcpip.ErrNoLinkAddress - } + entry.changeState(incomplete, time.Time{}) + fallthrough + case incomplete: + if waker != nil { + if entry.wakers == nil { + entry.wakers = make(map[*sleep.Waker]struct{}) + } + entry.wakers[waker] = struct{}{} + } - // Add 'incomplete' entry in the cache to mark that resolution is in progress. - e := c.makeAndAddEntry(k, "") - e.maybeAddWaker(waker) + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + } - go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + entry.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + } - return "", e.done, tcpip.ErrWouldBlock + return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } // removeWaker removes a waker previously added through get(). func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.mu.Lock() - defer c.mu.Unlock() + c.cache.Lock() + defer c.cache.Unlock() - if entry, ok := c.cache[k]; ok { + if entry, ok := c.cache.table[k]; ok { entry.removeWaker(waker) } } @@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) select { - case <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(k, i); stop { + case now := <-time.After(c.resolutionTimeout): + if stop := c.checkLinkRequest(now, k, i); stop { return } case <-done: @@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has succeeded // and mark the entry accordingly, e.g. ready, failed, etc. Return true if request // can stop, false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { + c.cache.Lock() + defer c.cache.Unlock() + entry, ok := c.cache.table[k] if !ok { // Entry was evicted from the cache. return true } - - switch s := entry.state(); s { - case ready, failed, expired: + switch s := entry.s; s { + case ready, failed: // Entry was made ready by resolver or failed. Either way we're done. - return true case incomplete: - if attempt+1 >= c.resolutionAttempts { - // Max number of retries reached, mark entry as failed. - entry.changeState(failed) - return true + if attempt+1 < c.resolutionAttempts { + // No response yet, need to send another ARP request. + return false } - // No response yet, need to send another ARP request. - return false + // Max number of retries reached, mark entry as failed. + entry.changeState(failed, now.Add(c.ageLimit)) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } + return true } func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - return &linkAddrCache{ + c := &linkAddrCache{ ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, - cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize), } + c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 924f4d240..9946b8fe8 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "sync" + "sync/atomic" "testing" "time" @@ -29,25 +30,34 @@ type testaddr struct { linkAddr tcpip.LinkAddress } -var testaddrs []testaddr +var testAddrs = func() []testaddr { + var addrs []testaddr + for i := 0; i < 4*linkAddrCacheSize; i++ { + addr := fmt.Sprintf("Addr%06d", i) + addrs = append(addrs, testaddr{ + addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + linkAddr: tcpip.LinkAddress("Link" + addr), + }) + } + return addrs +}() type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration + cache *linkAddrCache + delay time.Duration + onLinkAddressRequest func() } func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - go func() { - if r.delay > 0 { - time.Sleep(r.delay) - } - r.fakeRequest(addr) - }() + time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) + if f := r.onLinkAddressRequest; f != nil { + f() + } return nil } func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testaddrs { + for _, ta := range testAddrs { if ta.addr.Addr == addr { r.cache.add(ta.addr, ta.linkAddr) break @@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } -func init() { - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - testaddrs = append(testaddrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } -} - func TestCacheOverflow(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testaddrs) - 1; i >= 0; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= 0; i-- { + e := testAddrs[i] c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { @@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) { } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { - e := testaddrs[i] + e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) @@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. - for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { + e := testAddrs[i] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) } @@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) { for r := 0; r < 16; r++ { wg.Add(1) go func() { - for _, e := range testaddrs { + for _, e := range testAddrs { c.add(e.addr, e.linkAddr) c.get(e.addr, nil, "", nil, nil) // make work for gotsan } @@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } - e = testaddrs[0] + e = testAddrs[0] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } @@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) { func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { @@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) { func TestCacheReplace(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) @@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) { func TestCacheResolution(t *testing.T) { c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testaddrs { + for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) @@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) { c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} + var requestCount uint32 + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) + } + // First, sanity check that resolution is working... - e := testaddrs[0] + e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } + before := atomic.LoadUint32(&requestCount) + e.addr.Addr += "2" if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } + + if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } } func TestCacheResolutionTimeout(t *testing.T) { @@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) { c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - e := testaddrs[0] + e := testAddrs[0] if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4ef85bdfb..f64bbf6eb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -34,15 +34,13 @@ type NIC struct { linkEP LinkEndpoint loopback bool - demux *transportDemuxer - - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - subnets []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mu sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber]*ilist.List + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback name: name, linkEP: ep, loopback: loopback, - demux: newTransportDemuxer(stack), primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -102,6 +99,35 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Create an endpoint to receive broadcast packets on this interface. + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + if err := n.AddAddress(tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, + }, NeverPrimaryEndpoint); err != nil { + return err + } + } + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + } + + return nil +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -129,37 +155,6 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { - n.mu.RLock() - defer n.mu.RUnlock() - - var r *referencedNetworkEndpoint - - // Check for a primary endpoint. - if list, ok := n.primary[protocol]; ok { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) - if ref.holdsInsertRef && ref.tryIncRef() { - r = ref - break - } - } - - } - - if r == nil { - return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress - } - - addressWithPrefix := tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } - r.decRef() - - return addressWithPrefix, nil -} - // primaryEndpoint returns the primary endpoint of n for the given network // protocol. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { @@ -178,7 +173,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN case header.IPv4Broadcast, header.IPv4Any: continue } - if r.tryIncRef() { + if r.isValidForOutgoing() && r.tryIncRef() { return r } } @@ -197,22 +192,44 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // getRefEpOrCreateTemp returns the referenced network endpoint for the given // protocol and address. If none exists a temporary one may be created if -// requested. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint { +// we are in promiscuous mode or spoofing. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref + if ref, ok := n.endpoints[id]; ok { + // An endpoint with this id exists, check if it can be used and return it. + switch ref.getKind() { + case permanentExpired: + if !spoofingOrPromiscuous { + n.mu.RUnlock() + return nil + } + fallthrough + case temporary, permanent: + if ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + } } - // The address was not found, create a temporary one if requested by the - // caller or if the address is found in the NIC's subnets. - createTempEP := allowTemp + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets. + createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.subnets { + for _, sn := range n.addressRanges { + // Skip the subnet address. + if address == sn.ID() { + continue + } + // For now just skip the broadcast address, until we support it. + // FIXME(b/137608825): Add support for sending/receiving directed + // (subnet) broadcast. + if address == sn.Broadcast() { + continue + } if sn.Contains(address) { createTempEP = true break @@ -230,34 +247,70 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref + if ref, ok := n.endpoints[id]; ok { + // No need to check the type as we are ok with expired endpoints at this + // point. + if ref.tryIncRef() { + n.mu.Unlock() + return ref + } + // tryIncRef failing means the endpoint is scheduled to be removed once the + // lock is released. Remove it here so we can create a new (temporary) one. + // The removal logic waiting for the lock handles this case. + n.removeEndpointLocked(ref) } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { n.mu.Unlock() return nil } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, true) - - if ref != nil { - ref.holdsInsertRef = false - } + }, peb, temporary) n.mu.Unlock() return ref } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if ref, ok := n.endpoints[id]; ok { + switch ref.getKind() { + case permanent: + // The NIC already have a permanent endpoint with that address. + return nil, tcpip.ErrDuplicateAddress + case permanentExpired, temporary: + // Promote the endpoint to become permanent. + if ref.tryIncRef() { + ref.setKind(permanent) + return ref, nil + } + // tryIncRef failing means the endpoint is scheduled to be removed once + // the lock is released. Remove it here so we can create a new + // (permanent) one. The removal logic waiting for the lock handles this + // case. + n.removeEndpointLocked(ref) + } + } + return n.addAddressLocked(protocolAddress, peb, permanent) +} + +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP address before adding them. + + // Sanity check. + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if _, ok := n.endpoints[id]; ok { + // Endpoint already exists. + return nil, tcpip.ErrDuplicateAddress + } + netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol @@ -268,22 +321,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if err != nil { return nil, err } - - id := *ep.ID() - if ref, ok := n.endpoints[id]; ok { - if !replace { - return nil, tcpip.ErrDuplicateAddress - } - - n.removeEndpointLocked(ref) - } - ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - holdsInsertRef: true, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, } // Set up cache if link address resolution exists for this protocol. @@ -293,6 +336,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 unicast address, join the solicited-node + // multicast address. + if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -316,18 +368,26 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, false) + _, err := n.addPermanentAddressLocked(protocolAddress, peb) n.mu.Unlock() return err } -// Addresses returns the addresses associated with this NIC. -func (n *NIC) Addresses() []tcpip.ProtocolAddress { +// AllAddresses returns all addresses (primary and non-primary) associated with +// this NIC. +func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { + // Don't include expired or temporary endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -339,45 +399,66 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } -// AddSubnet adds a new subnet to n, so that it starts accepting packets -// targeted at the given address and network protocol. -func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { +// PrimaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + + var addrs []tcpip.ProtocolAddress + for proto, list := range n.primary { + for e := list.Front(); e != nil; e = e.Next() { + ref := e.(*referencedNetworkEndpoint) + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } + + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }, + }) + } + } + return addrs +} + +// AddAddressRange adds a range of addresses to n, so that it starts accepting +// packets targeted at the given addresses and network protocol. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.subnets = append(n.subnets, subnet) + n.addressRanges = append(n.addressRanges, subnet) n.mu.Unlock() } -// RemoveSubnet removes the given subnet from n. -func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { +// RemoveAddressRange removes the given address range from n. +func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.subnets[:0] - for _, sub := range n.subnets { + tmp := n.addressRanges[:0] + for _, sub := range n.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.subnets = tmp + n.addressRanges = tmp n.mu.Unlock() } -// ContainsSubnet reports whether this NIC contains the given subnet. -func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { - for _, s := range n.Subnets() { - if s == subnet { - return true - } - } - return false -} - // Subnets returns the Subnets associated with this NIC. -func (n *NIC) Subnets() []tcpip.Subnet { +func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) for nid := range n.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { @@ -387,19 +468,22 @@ func (n *NIC) Subnets() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.subnets...) + return append(sns, n.addressRanges...) } func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() - // Nothing to do if the reference has already been replaced with a - // different one. + // Nothing to do if the reference has already been replaced with a different + // one. This happens in the case where 1) this endpoint's ref count hit zero + // and was waiting (on the lock) to be removed and 2) the same address was + // re-added in the meantime by removing this endpoint from the list and + // adding a new one. if n.endpoints[id] != r { return } - if r.holdsInsertRef { + if r.getKind() == permanent { panic("Reference count dropped to zero before being removed") } @@ -418,15 +502,28 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || !r.holdsInsertRef { +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } - r.holdsInsertRef = false + r.setKind(permanentExpired) + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } - r.decRefLocked() + // At this point the endpoint is deleted. + + // If we are removing an IPv6 unicast address, leave the solicited-node + // multicast address. + if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -435,7 +532,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.removeAddressLocked(addr) + return n.removePermanentAddressLocked(addr) } // joinGroup adds a new endpoint for the given multicast address, if none @@ -444,6 +541,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -451,13 +555,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address if !ok { return tcpip.ErrUnknownProtocol } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint, false); err != nil { + }, NeverPrimaryEndpoint); err != nil { return err } } @@ -471,6 +575,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -479,7 +590,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress case 1: // This is the last one, clean up. - if err := n.removeAddressLocked(addr); err != nil { + if err := n.removePermanentAddressLocked(addr); err != nil { return err } } @@ -487,6 +598,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return nil } +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { + r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) + r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() +} + // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the physical interface. @@ -514,29 +632,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr src, dst := netProto.ParseAddresses(vv.First()) - // If the packet is destined to the IPv4 Broadcast address, then make a - // route to each IPv4 network endpoint and let each endpoint handle the - // packet. - if dst == header.IPv4Broadcast { - // n.endpoints is mutex protected so acquire lock. - n.mu.RLock() - for _, ref := range n.endpoints { - if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() - } - } - n.mu.RUnlock() - return - } - if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return } @@ -559,8 +656,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n := r.ref.nic n.mu.RLock() ref, ok := n.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() n.mu.RUnlock() - if ok && ref.tryIncRef() { + if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, vv) @@ -599,9 +697,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) { - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) - } + n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) if len(vv.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() @@ -615,9 +711,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { - return - } if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } @@ -631,7 +724,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { n.stack.stats.MalformedRcvdPackets.Increment() } } @@ -659,10 +752,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { - return - } - if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { return } } @@ -672,9 +762,38 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Stack returns the instance of the Stack that owns this NIC. +func (n *NIC) Stack() *Stack { + return n.stack +} + +type networkEndpointKind int32 + +const ( + // A permanent endpoint is created by adding a permanent address (vs. a + // temporary one) to the NIC. Its reference count is biased by 1 to avoid + // removal when no route holds a reference to it. It is removed by explicitly + // removing the permanent address from the NIC. + permanent networkEndpointKind = iota + + // An expired permanent endoint is a permanent endoint that had its address + // removed from the NIC, and it is waiting to be removed once no more routes + // hold a reference to it. This is achieved by decreasing its reference count + // by 1. If its address is re-added before the endpoint is removed, its type + // changes back to permanent and its reference count increases by 1 again. + permanentExpired + + // A temporary endpoint is created for spoofing outgoing packets, or when in + // promiscuous mode and accepting incoming packets that don't match any + // permanent endpoint. Its reference count is not biased by 1 and the + // endpoint is removed immediately when no more route holds a reference to + // it. A temporary endpoint can be promoted to permanent if its address + // is added permanently. + temporary +) + type referencedNetworkEndpoint struct { ilist.Entry - refs int32 ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -683,11 +802,34 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool + // refs is counting references held for this endpoint. When refs hits zero it + // triggers the automatic removal of the endpoint from the NIC. + refs int32 + + // networkEndpointKind must only be accessed using {get,set}Kind(). + kind networkEndpointKind +} + +func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { + return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) +} + +func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { + atomic.StoreInt32((*int32)(&r.kind), int32(kind)) +} + +// isValidForOutgoing returns true if the endpoint can be used to send out a +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in spoofing mode. +func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { + return r.getKind() != permanentExpired || r.nic.spoofing +} + +// isValidForIncoming returns true if the endpoint can accept an incoming +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in promiscuous mode. +func (r *referencedNetworkEndpoint) isValidForIncoming() bool { + return r.getKind() != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches @@ -699,11 +841,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is @@ -728,3 +873,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool { } } } + +// stack returns the Stack instance that owns the underlying endpoint. +func (r *referencedNetworkEndpoint) stack() *Stack { + return r.nic.stack +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2037eef9f..9d6157f22 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -109,7 +107,7 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -148,6 +146,19 @@ const ( PacketLoop ) +// NetworkHeaderParams are the header parameters given as input by the +// transport endpoint to the network. +type NetworkHeaderParams struct { + // Protocol refers to the transport protocol number. + Protocol tcpip.TransportProtocolNumber + + // TTL refers to Time To Live field of the IP-header. + TTL uint8 + + // TOS refers to TypeOfService or TrafficClass field of the IP-header. + TOS uint8 +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { @@ -172,7 +183,7 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error + WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. @@ -297,6 +308,15 @@ type LinkEndpoint interface { // IsAttached returns whether a NetworkDispatcher is attached to the // endpoint. IsAttached() bool + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // For now, requesting that an endpoint's worker goroutine(s) stop is + // implementation specific. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -359,14 +379,6 @@ type LinkAddressCache interface { RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// TransportProtocolFactory functions are used by the stack to instantiate -// transport protocols. -type TransportProtocolFactory func() TransportProtocol - -// NetworkProtocolFactory provides methods to be used by the stack to -// instantiate network protocols. -type NetworkProtocolFactory func() NetworkProtocol - // UnassociatedEndpointFactory produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can be used // to write arbitrary packets that include the IP header. @@ -374,60 +386,6 @@ type UnassociatedEndpointFactory interface { NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } -var ( - transportProtocols = make(map[string]TransportProtocolFactory) - networkProtocols = make(map[string]NetworkProtocolFactory) - - unassociatedFactory UnassociatedEndpointFactory - - linkEPMu sync.RWMutex - nextLinkEndpointID tcpip.LinkEndpointID = 1 - linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) -) - -// RegisterTransportProtocolFactory registers a new transport protocol factory -// with the stack so that it becomes available to users of the stack. This -// function is intended to be called by init() functions of the protocols. -func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) { - transportProtocols[name] = p -} - -// RegisterNetworkProtocolFactory registers a new network protocol factory with -// the stack so that it becomes available to users of the stack. This function -// is intended to be called by init() functions of the protocols. -func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { - networkProtocols[name] = p -} - -// RegisterUnassociatedFactory registers a factory to produce endpoints not -// associated with any particular transport protocol. This function is intended -// to be called by init() functions of the protocols. -func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { - unassociatedFactory = f -} - -// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an -// ID that can be used to refer to it. -func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { - linkEPMu.Lock() - defer linkEPMu.Unlock() - - v := nextLinkEndpointID - nextLinkEndpointID++ - - linkEndpoints[v] = linkEP - - return v -} - -// FindLinkEndpoint finds the link endpoint associated with the given ID. -func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { - linkEPMu.RLock() - defer linkEPMu.RUnlock() - - return linkEndpoints[id] -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 391ab4344..e72373964 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop = PacketLoop } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop } return Route{ @@ -148,12 +150,16 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. func (r *Route) IsResolutionRequired() bool { - return r.ref.linkCache != nil && r.RemoteLinkAddress == "" + return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { - err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) +func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + + err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { @@ -166,6 +172,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err @@ -200,12 +210,24 @@ func (r *Route) Clone() Route { return *r } -// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast. +// MakeLoopedRoute duplicates the given route with special handling for routes +// used for sending multicast or broadcast packets. In those cases the +// multicast/broadcast address is the remote address when sending out, but for +// incoming (looped) packets it becomes the local address. Similarly, the local +// interface address that was the local address going out becomes the remote +// address coming in. This is different to unicast routes where local and +// remote addresses remain the same as they identify location (local vs remote) +// not direction (source vs destination). func (r *Route) MakeLoopedRoute() Route { l := r.Clone() - if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { + if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress l.RemoteLinkAddress = l.LocalLinkAddress } return l } + +// Stack returns the instance of the Stack that owns this route. +func (r *Route) Stack() *Stack { + return r.ref.stack() +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d69162ba1..f67975525 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -17,17 +17,15 @@ // // For consumers, the only function of interest is New(), everything else is // provided by the tcpip/public package. -// -// For protocol implementers, RegisterTransportProtocolFactory() and -// RegisterNetworkProtocolFactory() are used to register protocol factories with -// the stack, which will then be used to instantiate protocol objects when -// consumers interact with the stack. package stack import ( + "encoding/binary" "sync" "time" + "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -45,6 +43,9 @@ const ( resolutionTimeout = 1 * time.Second // resolutionAttempts is set to the same ARP retries used in Linux. resolutionAttempts = 3 + + // DefaultTOS is the default type of service value for network endpoints. + DefaultTOS = 0 ) type transportProtocolState struct { @@ -350,6 +351,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // unassociatedFactory creates unassociated endpoints. If nil, raw + // endpoints are disabled. It is set during Stack creation and is + // immutable. unassociatedFactory UnassociatedEndpointFactory demux *transportDemuxer @@ -358,10 +362,6 @@ type Stack struct { linkAddrCache *linkAddrCache - // raw indicates whether raw sockets may be created. It is set during - // Stack creation and is immutable. - raw bool - mu sync.RWMutex nics map[tcpip.NICID]*NIC forwarding bool @@ -389,10 +389,26 @@ type Stack struct { // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. resumableEndpoints []ResumableEndpoint + + // icmpRateLimiter is a global rate limiter for all ICMP messages generated + // by the stack. + icmpRateLimiter *ICMPRateLimiter + + // portSeed is a one-time random value initialized at stack startup + // and is used to seed the TCP port picking on active connections + // + // TODO(gvisor.dev/issue/940): S/R this field. + portSeed uint32 } // Options contains optional Stack configuration. type Options struct { + // NetworkProtocols lists the network protocols to enable. + NetworkProtocols []NetworkProtocol + + // TransportProtocols lists the transport protocols to enable. + TransportProtocols []TransportProtocol + // Clock is an optional clock source used for timestampping packets. // // If no Clock is specified, the clock source will be time.Now. @@ -406,10 +422,39 @@ type Options struct { // stack (false). HandleLocal bool - // Raw indicates whether raw sockets may be created. - Raw bool + // UnassociatedFactory produces unassociated endpoints raw endpoints. + // Raw endpoints are enabled only if this is non-nil. + UnassociatedFactory UnassociatedEndpointFactory } +// TransportEndpointInfo holds useful information about a transport endpoint +// which can be queried by monitoring tools. +// +// +stateify savable +type TransportEndpointInfo struct { + // The following fields are initialized at creation time and are + // immutable. + + NetProto tcpip.NetworkProtocolNumber + TransProto tcpip.TransportProtocolNumber + + // The following fields are protected by endpoint mu. + + ID TransportEndpointID + // BindNICID and bindAddr are set via calls to Bind(). They are used to + // reject attempts to send data or connect via a different NIC or + // address + BindNICID tcpip.NICID + BindAddr tcpip.Address + // RegisterNICID is the default NICID registered as a side-effect of + // connect or datagram write. + RegisterNICID tcpip.NICID +} + +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*TransportEndpointInfo) IsEndpointInfo() {} + // New allocates a new networking stack with only the requested networking and // transport protocols configured with default options. // @@ -417,7 +462,7 @@ type Options struct { // SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the // stack. Please refer to individual protocol implementations as to what options // are supported. -func New(network []string, transport []string, opts Options) *Stack { +func New(opts Options) *Stack { clock := opts.Clock if clock == nil { clock = &tcpip.StdClock{} @@ -433,16 +478,12 @@ func New(network []string, transport []string, opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - raw: opts.Raw, + icmpRateLimiter: NewICMPRateLimiter(), + portSeed: generateRandUint32(), } // Add specified network protocols. - for _, name := range network { - netProtoFactory, ok := networkProtocols[name] - if !ok { - continue - } - netProto := netProtoFactory() + for _, netProto := range opts.NetworkProtocols { s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -450,18 +491,14 @@ func New(network []string, transport []string, opts Options) *Stack { } // Add specified transport protocols. - for _, name := range transport { - transProtoFactory, ok := transportProtocols[name] - if !ok { - continue - } - transProto := transProtoFactory() + for _, transProto := range opts.TransportProtocols { s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } } - s.unassociatedFactory = unassociatedFactory + // Add the factory for unassociated endpoints, if present. + s.unassociatedFactory = opts.UnassociatedFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -596,7 +633,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if !s.raw { + if s.unassociatedFactory == nil { return nil, tcpip.ErrNotPermitted } @@ -614,12 +651,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { - ep := FindLinkEndpoint(linkEP) - if ep == nil { - return tcpip.ErrBadLinkEndpoint - } - +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -632,40 +664,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint s.nics[id] = n if enabled { - n.attachLinkEndpoint() + return n.enable() } return nil } // CreateNIC creates a NIC with the provided id and link-layer endpoint. -func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true, false) +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, false) +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, false) } // CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer // endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, true) +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false, false) +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false, false) +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -679,9 +711,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - nic.attachLinkEndpoint() - - return nil + return nic.enable() } // CheckNIC checks if a NIC is usable. @@ -696,14 +726,14 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // NICSubnets returns a map of NICIDs to their associated subnets. -func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet { +func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() nics := map[tcpip.NICID][]tcpip.Subnet{} for id, nic := range s.nics { - nics[id] = append(nics[id], nic.Subnets()...) + nics[id] = append(nics[id], nic.AddressRanges()...) } return nics } @@ -739,7 +769,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.Addresses(), + ProtocolAddresses: nic.PrimaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -804,71 +834,79 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return nic.AddAddress(protocolAddress, peb) } -// AddSubnet adds a subnet range to the specified NIC. -func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { +// AddAddressRange adds a range of addresses to the specified NIC. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.AddSubnet(protocol, subnet) + nic.AddAddressRange(protocol, subnet) return nil } return tcpip.ErrUnknownNICID } -// RemoveSubnet removes the subnet range from the specified NIC. -func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { +// RemoveAddressRange removes the range of addresses from the specified NIC. +func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.RemoveSubnet(subnet) + nic.RemoveAddressRange(subnet) return nil } return tcpip.ErrUnknownNICID } -// ContainsSubnet reports whether the specified NIC contains the specified -// subnet. -func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) { +// RemoveAddress removes an existing network-layer address from the specified +// NIC. +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - return nic.ContainsSubnet(subnet), nil + return nic.RemoveAddress(addr) } - return false, tcpip.ErrUnknownNICID + return tcpip.ErrUnknownNICID } -// RemoveAddress removes an existing network-layer address from the specified -// NIC. -func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { +// AllAddresses returns a map of NICIDs to their protocol addresses (primary +// and non-primary). +func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.RemoveAddress(addr) + nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) + for id, nic := range s.nics { + nics[id] = nic.AllAddresses() } - - return tcpip.ErrUnknownNICID + return nics } -// GetMainNICAddress returns the first primary address (and the subnet that -// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's -// address if no primary addresses exist. Returns an error if the NIC doesn't -// exist or has no endpoints. +// GetMainNICAddress returns the first primary address and prefix for the given +// NIC and protocol. Returns an error if the NIC doesn't exist and an empty +// value if the NIC doesn't have a primary address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.getMainNICAddress(protocol) + nic, ok := s.nics[id] + if !ok { + return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + for _, a := range nic.PrimaryAddresses() { + if a.Protocol == protocol { + return a.AddressWithPrefix, nil + } + } + return tcpip.AddressWithPrefix{}, nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { @@ -895,7 +933,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { @@ -1035,73 +1073,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { - if nicID == 0 { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { - if nicID == 0 { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterEndpoint(netProtos, protocol, id, ep) - } +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { - if nicID == 0 { - return s.demux.registerRawEndpoint(netProto, transProto, ep) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerRawEndpoint(netProto, transProto, ep) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { - if nicID == 0 { - s.demux.unregisterRawEndpoint(netProto, transProto, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterRawEndpoint(netProto, transProto, ep) - } + s.demux.unregisterRawEndpoint(netProto, transProto, ep) } // RegisterRestoredEndpoint records e as an endpoint that has been restored on @@ -1215,3 +1207,49 @@ func (s *Stack) IPTables() iptables.IPTables { func (s *Stack) SetIPTables(ipt iptables.IPTables) { s.tables = ipt } + +// ICMPLimit returns the maximum number of ICMP messages that can be sent +// in one second. +func (s *Stack) ICMPLimit() rate.Limit { + return s.icmpRateLimiter.Limit() +} + +// SetICMPLimit sets the maximum number of ICMP messages that be sent +// in one second. +func (s *Stack) SetICMPLimit(newLimit rate.Limit) { + s.icmpRateLimiter.SetLimit(newLimit) +} + +// ICMPBurst returns the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) ICMPBurst() int { + return s.icmpRateLimiter.Burst() +} + +// SetICMPBurst sets the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) SetICMPBurst(burst int) { + s.icmpRateLimiter.SetBurst(burst) +} + +// AllowICMPMessage returns true if we the rate limiter allows at least one +// ICMP message to be sent at this instant. +func (s *Stack) AllowICMPMessage() bool { + return s.icmpRateLimiter.Allow() +} + +// PortSeed returns a 32 bit value that can be used as a seed value for port +// picking. +// +// NOTE: The seed is generated once during stack initialization only. +func (s *Stack) PortSeed() uint32 { + return s.portSeed +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return binary.LittleEndian.Uint32(b) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 137c6183e..10fd1065f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct { prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { @@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -116,10 +116,10 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto } func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() + return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -128,7 +128,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu b := hdr.Prepend(fakeNetHeaderLen) b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] - b[2] = byte(protocol) + b[2] = byte(params.Protocol) if loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(payload.Views())) @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return nil } - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) } func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -181,18 +181,22 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } +func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { + return f.packetCount[int(intfAddr)%len(f.packetCount)] +} + func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, - linkEP: linkEP, + ep: ep, }, nil } @@ -218,12 +222,18 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeNetFactory() stack.NetworkProtocol { + return &fakeNetworkProtocol{} +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -241,7 +251,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -251,7 +261,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -261,7 +271,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -270,7 +280,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -280,7 +290,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -289,16 +299,75 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatal("FindRoute failed:", err) + return err } defer r.Release() + return send(r, payload) +} +func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Error("WritePacket failed:", err) + return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}) +} + +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { + t.Helper() + ep.Drain() + if err := sendTo(s, addr, payload); err != nil { + t.Error("sendTo failed:", err) + } + if got, want := ep.Drain(), 1; got != want { + t.Errorf("sendTo packet count: got = %d, want %d", got, want) + } +} + +func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { + t.Helper() + ep.Drain() + if err := send(r, payload); err != nil { + t.Error("send failed:", err) + } + if got, want := ep.Drain(), 1; got != want { + t.Errorf("send packet count: got = %d, want %d", got, want) + } +} + +func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := send(r, payload); gotErr != wantErr { + t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := sendTo(s, addr, payload); gotErr != wantErr { + t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + 1 + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) +} + +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we do NOT expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) +} + +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { + t.Helper() + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + if got := fakeNet.PacketCount(localAddrByte); got != want { + t.Errorf("receive packet count: got = %d, want %d", got, want) } } @@ -306,9 +375,11 @@ func TestNetworkSend(t *testing.T) { // Create a stack with the fake network protocol, one nic, and one // address: 1. The route table sends all packets through the only // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -325,20 +396,19 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - sendTo(t, s, "\x03", nil) - if c := linkEP.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x03", ep, nil) } func TestNetworkSendMultiRoute(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -350,8 +420,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -382,18 +452,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - sendTo(t, s, "\x05", nil) - - if c := linkEP1.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x05", ep1, nil) // Send a packet to an even destination. - sendTo(t, s, "\x06", nil) - - if c := linkEP2.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x06", ep2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -424,10 +486,12 @@ func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -439,8 +503,8 @@ func TestRoutes(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -498,58 +562,71 @@ func TestRoutes(t *testing.T) { } func TestAddressRemoval(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) - // Remove the address, then check that packet doesn't get delivered - // anymore. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } } -func TestDelayedRemovalDueToRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +func TestAddressRemovalWithRouteHeld(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatal("CreateNIC failed:", err) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { @@ -558,58 +635,239 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - - // Get a route, check that packet is still deliverable. - r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 2 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) - // Remove the address, then check that packet is still deliverable - // because the route is keeping the address alive. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) - } + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } +} + +func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { + t.Helper() + info, ok := s.NICInfo()[nicid] + if !ok { + t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + } + if len(addr) == 0 { + // No address given, verify that there is no address assigned to the NIC. + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + } + } + return + } + // Address given, verify the address is assigned to the NIC and no other + // address is. + found := false + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber { + if a.AddressWithPrefix.Address == addr { + found = true + } else { + t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) + } + } + } + if !found { + t.Errorf("verify address: couldn't find %s on the NIC", addr) + } +} + +func TestEndpointExpiration(t *testing.T) { + const ( + localAddrByte byte = 0x01 + remoteAddr tcpip.Address = "\x03" + noAddr tcpip.Address = "" + nicid tcpip.NICID = 1 + ) + localAddr := tcpip.Address([]byte{localAddrByte}) + + for _, promiscuous := range []bool{true, false} { + for _, spoofing := range []bool{true, false} { + t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) + buf[0] = localAddrByte + + if promiscuous { + if err := s.SetPromiscuousMode(nicid, true); err != nil { + t.Fatal("SetPromiscuousMode failed:", err) + } + } + + if spoofing { + if err := s.SetSpoofing(nicid, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + } + + // 1. No Address yet, send should only work for spoofing, receive for + // promiscuous mode. + //----------------------- + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, ep, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, ep, nil) + } else { + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + } - // Release the route, then check that packet is not deliverable anymore. - r.Release() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) + // 2. Add Address, everything should work. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + + // 3. Remove the address, send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, ep, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, ep, nil) + } else { + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + } + + // 4. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + + // 5. Take a reference to the endpoint by getting a route. Verify that + // we can still send/receive, including sending using the route. + //----------------------- + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) + + // 6. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, ep, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + } + if spoofing { + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) + } else { + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + } + + // 7. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) + + // 8. Remove the route, sendTo/recv should still work. + //----------------------- + r.Release() + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + + // 9. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, ep, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, ep, nil) + } else { + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + } + }) + } } } func TestPromiscuousMode(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -627,22 +885,15 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + const localAddrByte byte = 0x01 + buf[0] = localAddrByte + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testRecv(t, fakeNet, localAddrByte, ep, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -655,25 +906,24 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } -func TestAddressSpoofing(t *testing.T) { - srcAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") +func TestSpoofingWithAddress(t *testing.T) { + localAddr := tcpip.Address("\x01") + nonExistentLocalAddr := tcpip.Address("\x02") + dstAddr := tcpip.Address("\x03") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } @@ -687,61 +937,217 @@ func TestAddressSpoofing(t *testing.T) { // With address spoofing disabled, FindRoute does not permit an address // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err == nil { + t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) + } + + // With address spoofing enabled, FindRoute permits any address to be used + // as the source. + if err := s.SetSpoofing(1, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) + } + // Sending a packet works. + testSendTo(t, s, dstAddr, ep, nil) + testSend(t, r, ep, nil) + + // FindRoute should also work with a local address that exists on the NIC. + r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != localAddr { + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) + } + // Sending a packet using the route works. + testSend(t, r, ep, nil) +} + +func TestSpoofingNoAddress(t *testing.T) { + nonExistentLocalAddr := tcpip.Address("\x01") + dstAddr := tcpip.Address("\x02") + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // With address spoofing disabled, FindRoute does not permit an address + // that was not added to the NIC to be used as the source. + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err == nil { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } + // Sending a packet fails. + testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. if err := s.SetSpoofing(1, true); err != nil { t.Fatal("SetSpoofing failed:", err) } - r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - if r.LocalAddress != srcAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } + // Sending a packet works. + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, ep, nil) } -func TestBroadcastNeedsNoRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +func verifyRoute(gotRoute, wantRoute stack.Route) error { + if gotRoute.LocalAddress != wantRoute.LocalAddress { + return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) + } + if gotRoute.RemoteAddress != wantRoute.RemoteAddress { + return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) + } + if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { + return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + } + if gotRoute.NextHop != wantRoute.NextHop { + return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) + } + return nil +} + +func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } - if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + if err := s.AddProtocolAddress(1, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if r.LocalAddress != header.IPv4Any { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any) + // If the NIC doesn't exist, it won't work. + if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { + t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } +} - if r.RemoteAddress != header.IPv4Broadcast { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast) +func TestOutgoingBroadcastWithRouteTable(t *testing.T) { + defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. + nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. + nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") + + // Create a new stack with two NICs. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + if err := s.CreateNIC(2, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { + t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + } + + // Set the initial route table. + rt := []tcpip.Route{ + {Destination: nic1Addr.Subnet(), NIC: 1}, + {Destination: nic2Addr.Subnet(), NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, + } + s.SetRouteTable(rt) + + // When an interface is given, the route for a broadcast goes through it. + r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + + // When an interface is not given, it consults the route table. + // 1. Case: Using the default route. + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) + } + + // 2. Case: Having an explicit route for broadcast will select that one. + rt = append( + []tcpip.Route{ + {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + }, + rt..., + ) + s.SetRouteTable(rt) + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -781,10 +1187,12 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, } { t.Run(tc.name, func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -835,12 +1243,14 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Set the subnet, then check that packet is delivered. -func TestSubnetAcceptsMatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +// Add a range of addresses, then check that a packet is delivered. +func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -856,29 +1266,59 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + testRecv(t, fakeNet, localAddrByte, ep, buf) +} + +func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { + t.Helper() + + // Loop over all addresses and check them. + numOfAddresses := 1 << uint(8-subnet.Prefix()) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + + addrBytes := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + addr := tcpip.Address(addrBytes) + wantNicID := nicID + // The subnet and broadcast addresses are skipped. + if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { + wantNicID = 0 + } + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) + } + addrBytes[0]++ + } + + // Trying the next address should always fail since it is outside the range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) } } -// Set the subnet, then check that CheckLocalAddress returns the correct NIC. +// Set a range of addresses, then remove it again, and check at each step that +// CheckLocalAddress returns the correct NIC for each address or zero if not +// existent. func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -891,39 +1331,34 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { } subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) - } - // Loop over all subnet addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - addr := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) - } - addr[0]++ + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) + + if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - // Trying the next address should fail since it is outside the subnet range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) + testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) + + if err := s.RemoveAddressRange(nicID, subnet); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } + + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) } -// Set destination outside the subnet, then check it doesn't get delivered. -func TestSubnetRejectsNonmatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +// Set a range of addresses, then send a packet to a destination outside the +// range and then check it doesn't get delivered. +func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -939,23 +1374,23 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) - } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestNetworkOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{}, + }) // Try an unsupported network protocol. if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -994,44 +1429,53 @@ func TestNetworkOptions(t *testing.T) { } } -func TestSubnetAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { +func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { + ranges, ok := s.NICAddressRanges()[id] + if !ok { + return false + } + for _, r := range ranges { + if r == addrRange { + return true + } + } + return false +} + +func TestAddresRangeAddRemove(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } addr := tcpip.Address("\x01\x01\x01\x01") mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - subnet, err := tcpip.NewSubnet(addr, mask) + addrRange, err := tcpip.NewSubnet(addr, mask) if err != nil { t.Fatal("NewSubnet failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { + t.Fatal("AddAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if !contained { - t.Fatal("got s.ContainsSubnet(...) = false, want = true") + if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatal("RemoveSubnet failed:", err) + if err := s.RemoveAddressRange(1, addrRange); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } } @@ -1042,9 +1486,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } // Insert <canBe> primary and <never> never-primary addresses. @@ -1082,20 +1528,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } if len(primaryAddrAdded) == 0 { - // No primary addresses present, expect an error. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress) + // No primary addresses present. + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) } } else { - // At least one primary address was added, expect a valid - // address and prefixLen. - gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok { - t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded) + // At least one primary address was added, verify the returned + // address is in the list of primary addresses we added. + if _, ok := primaryAddrAdded[gotAddr]; !ok { + t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) } } }) @@ -1107,9 +1553,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1134,19 +1582,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { t.Fatal("GetMainNICAddress failed:", err) - } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix) + } + if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { + t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { t.Fatal("RemoveAddress failed:", err) } - // Check that we get an error after removal. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress) + // Check that we get no address after removal. + gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } }) } @@ -1161,8 +1615,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address { } func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { + t.Helper() + if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses)) + t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) } sort.Slice(gotAddresses, func(i, j int) bool { @@ -1182,9 +1638,11 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1201,15 +1659,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1233,15 +1693,17 @@ func TestAddProtocolAddress(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1262,15 +1724,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1297,15 +1761,17 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestNICStats(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { - t.Fatal("CreateNIC failed:", err) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) @@ -1321,7 +1787,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1332,10 +1798,12 @@ func TestNICStats(t *testing.T) { payload := buffer.NewView(10) // Write a packet out via the address for NIC 1 - sendTo(t, s, "\x01", payload) - want := uint64(linkEP1.Drain()) + if err := sendTo(s, "\x01", payload); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := uint64(ep1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) + t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { @@ -1346,19 +1814,21 @@ func TestNICStats(t *testing.T) { func TestNICForwarding(t *testing.T) { // Create a stack with the fake network protocol, two NICs, each with // an address. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) s.SetForwarding(true) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -1377,10 +1847,10 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) select { - case <-linkEP2.C: + case <-ep2.C: default: t.Fatal("Packet not forwarded") } @@ -1394,9 +1864,3 @@ func TestNICForwarding(t *testing.T) { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } - -func init() { - stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { - return &fakeNetworkProtocol{} - }) -} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index cf8a6d129..92267ce4d 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -35,25 +35,109 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]TransportEndpoint + endpoints map[TransportEndpointID]*endpointsByNic // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint } +type endpointsByNic struct { + mu sync.RWMutex + endpoints map[tcpip.NICID]*multiPortEndpoint + // seed is a random secret for a jenkins hash. + seed uint32 +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + } + + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints bound to the right device. + if isMulticastOrBroadcast(id.LocalAddress) { + mpep.handlePacketAll(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[n.ID()] + if !ok { + mpep, ok = epsByNic.endpoints[0] + } + if !ok { + return + } + + // TODO(eyalsoha): Why don't we look at id to see if this packet needs to + // broadcast like we are doing with handlePacket above? + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) +} + +// registerEndpoint returns true if it succeeds. It fails and returns +// false if ep already has an element with the same key. +func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + + if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { + // There was already a bind. + return multiPortEp.singleRegisterEndpoint(t, reusePort) + } + + // This is a new binding. + multiPortEp := &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.reuse = reusePort + epsByNic.endpoints[bindToDevice] = multiPortEp + return multiPortEp.singleRegisterEndpoint(t, reusePort) +} + +// unregisterEndpoint returns true if endpointsByNic has to be unregistered. +func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + return false + } + if multiPortEp.unregisterEndpoint(t) { + delete(epsByNic.endpoints, bindToDevice) + } + return len(epsByNic.endpoints) == 0 +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - e, ok := eps.endpoints[id] + epsByNic, ok := eps.endpoints[id] if !ok { return } - if multiPortEp, ok := e.(*multiPortEndpoint); ok { - if !multiPortEp.unregisterEndpoint(ep) { - return - } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return } delete(eps.endpoints, id) } @@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]TransportEndpoint), + endpoints: make(map[TransportEndpointID]*endpointsByNic), } } } @@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) return err } } @@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // multiPortEndpoint is a container for TransportEndpoints which are bound to -// the same pair of address and port. +// the same pair of address and port. endpointsArr always has at least one +// element. type multiPortEndpoint struct { mu sync.RWMutex endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int - // seed is a random secret for a jenkins hash. - seed uint32 + // reuse indicates if more than one endpoint is allowed. + reuse bool } // reciprocalScale scales a value into range [0, n). @@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { - ep.mu.RLock() - defer ep.mu.RUnlock() +func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { + if len(mpep.endpointsArr) == 1 { + return mpep.endpointsArr[0] + } payload := []byte{ byte(id.LocalPort), @@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd byte(id.RemotePort >> 8), } - h := jenkins.Sum32(ep.seed) + h := jenkins.Sum32(seed) h.Write(payload) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) - return ep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) + return mpep.endpointsArr[idx] } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast or multicast datagram, deliver the datagram to all - // endpoints managed by ep. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { - for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy. - if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) - break - } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + ep.mu.RLock() + for i, endpoint := range ep.endpointsArr { + // HandlePacket modifies vv, so each endpoint needs its own copy except for + // the final one. + if i == len(ep.endpointsArr)-1 { + endpoint.HandlePacket(r, id, vv) + break } - } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + vvCopy := buffer.NewView(vv.Size()) + copy(vvCopy, vv.ToView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) } + ep.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { - ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) -} - -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { +// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint +// list. The list might be empty already. +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - // A new endpoint is added into endpointsArr and its index there is - // saved in endpointsMap. This will allows to remove endpoint from - // the array fast. + if len(ep.endpointsArr) > 0 { + // If it was previously bound, we need to check if we can bind again. + if !ep.reuse || !reusePort { + return tcpip.ErrPortInUse + } + } + + // A new endpoint is added into endpointsArr and its index there is saved in + // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. @@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { return true } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { + // TODO(eyalsoha): Why? reusePort = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return nil + return tcpip.ErrUnknownProtocol } eps.mu.Lock() defer eps.mu.Unlock() - var multiPortEp *multiPortEndpoint - if _, ok := eps.endpoints[id]; ok { - if !reusePort { - return tcpip.ErrPortInUse - } - multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) - if !ok { - return tcpip.ErrPortInUse - } + if epsByNic, ok := eps.endpoints[id]; ok { + // There was already a binding. + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } - if reusePort { - if multiPortEp == nil { - multiPortEp = &multiPortEndpoint{} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.seed = rand.Uint32() - eps.endpoints[id] = multiPortEp - } - - multiPortEp.singleRegisterEndpoint(ep) - - return nil + // This is a new binding. + epsByNic := &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), } - eps.endpoints[id] = ep + eps.endpoints[id] = epsByNic - return nil + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep) + eps.unregisterEndpoint(id, ep, bindToDevice) } } } @@ -265,23 +338,14 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // If a sender bound to the Loopback interface sends a broadcast, - // that broadcast must not be delivered to the sender. - if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort { - return false - } - - // If the packet is a broadcast, then find all matching transport endpoints. - // Otherwise, try to find a single matching transport endpoint. - destEps := make([]TransportEndpoint, 0, 1) eps.mu.RLock() - if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { - for epID, endpoint := range eps.endpoints { - if epID.LocalPort == id.LocalPort { - destEps = append(destEps, endpoint) - } - } + // Determine which transport endpoint or endpoints to deliver this packet to. + // If the packet is a broadcast or multicast, then find all matching + // transport endpoints. + var destEps []*endpointsByNic + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, vv, id) } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { destEps = append(destEps, ep) } @@ -299,7 +363,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + ep.handlePacket(r, id, vv) } return true @@ -331,7 +395,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -348,15 +412,16 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, } // Deliver the packet. - ep.HandleControlPacket(id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, vv) return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the local address. @@ -364,7 +429,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the remote part. @@ -372,15 +437,24 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given +// id. +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { + return matchedEPs[0] + } return nil } @@ -418,3 +492,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } } } + +func isMulticastOrBroadcast(addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go new file mode 100644 index 000000000..210233dc0 --- /dev/null +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -0,0 +1,352 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testPort = 4096 +) + +type testContext struct { + t *testing.T + linkEPs map[string]*channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v6only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } +} + +// newDualTestContextMultiNic creates the testing context and also linkEpNames +// named NICs. +func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + linkEPs := make(map[string]*channel.Endpoint) + for i, linkEpName := range linkEpNames { + channelEP := channel.New(256, mtu, "") + nicid := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + linkEPs[linkEpName] = channelEP + + if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress IPv4 failed: %v", err) + } + + if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress IPv6 failed: %v", err) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEPs: linkEPs, + } +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) +} + +func TestTransportDemuxerRegister(t *testing.T) { + for _, test := range []struct { + name string + proto tcpip.NetworkProtocolNumber + want *tcpip.Error + }{ + {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"success", ipv4.ProtocolNumber, nil}, + } { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + } + }) + } +} + +// TestReuseBindToDevice injects varied packets on input devices and checks that +// the distribution of packets received matches expectations. +func TestDistribution(t *testing.T) { + type endpointSockopts struct { + reuse int + bindToDevice string + } + for _, test := range []struct { + name string + // endpoints will received the inject packets. + endpoints []endpointSockopts + // wantedDistribution is the wanted ratio of packets received on each + // endpoint for each NIC on which packets are injected. + wantedDistributions map[string][]float64 + }{ + { + "BindPortReuse", + // 5 endpoints that all have reuse set. + []endpointSockopts{ + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed evenly. + "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + }, + }, + { + "BindToDevice", + // 3 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{0, "dev0"}, + endpointSockopts{0, "dev1"}, + endpointSockopts{0, "dev2"}, + }, + map[string][]float64{ + // Injected packets on dev0 go only to the endpoint bound to dev0. + "dev0": []float64{1, 0, 0}, + // Injected packets on dev1 go only to the endpoint bound to dev1. + "dev1": []float64{0, 1, 0}, + // Injected packets on dev2 go only to the endpoint bound to dev2. + "dev2": []float64{0, 0, 1}, + }, + }, + { + "ReuseAndBindToDevice", + // 6 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed among endpoints bound to + // dev0. + "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + // Injected packets on dev1 get distributed among endpoints bound to + // dev1 or unbound. + "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + // Injected packets on dev999 go only to the unbound. + "dev999": []float64{0, 0, 0, 0, 0, 1}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + for device, wantedDistribution := range test.wantedDistributions { + t.Run(device, func(t *testing.T) { + var devices []string + for d := range test.wantedDistributions { + devices = append(devices, d) + } + c := newDualTestContextMultiNic(t, defaultMTU, devices) + defer c.cleanup() + + c.createV6Endpoint(false) + + eps := make(map[tcpip.Endpoint]int) + + pollChannel := make(chan tcpip.Endpoint) + for i, endpoint := range test.endpoints { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + eps[ep] = i + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(ep) + + defer ep.Close() + reusePortOption := tcpip.ReusePortOption(endpoint.reuse) + if err := ep.SetSockOpt(reusePortOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + } + bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) + if err := ep.SetSockOpt(bindToDeviceOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + } + if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + } + } + + npackets := 100000 + nports := 10000 + if got, want := len(test.endpoints), len(wantedDistribution); got != want { + t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) + } + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, + &headers{ + srcPort: testPort + port, + dstPort: stackPort}, + device) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled by the same + // socket. + if want, got := ports[port], ep; want != got { + t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) + } + } + } + + // Check that a packet distribution is as expected. + for ep, i := range eps { + wantedRatio := wantedDistribution[i] + wantedRecv := wantedRatio * float64(npackets) + actualRecv := stats[ep] + actualRatio := float64(stats[ep]) / float64(npackets) + // The deviation is less than 10%. + if math.Abs(actualRatio-wantedRatio) > 0.05 { + t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5335897f5..86c62be25 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -38,9 +38,8 @@ const ( // Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't // use it. type fakeTransportEndpoint struct { - id stack.TransportEndpointID + stack.TransportEndpointInfo stack *stack.Stack - netProto tcpip.NetworkProtocolNumber proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route @@ -49,8 +48,16 @@ type fakeTransportEndpoint struct { acceptQueue []fakeTransportEndpoint } -func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto} +func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { + return &f.TransportEndpointInfo +} + +func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { + return nil +} + +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} } func (f *fakeTransportEndpoint) Close() { @@ -65,17 +72,17 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { + if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil { return 0, nil, err } @@ -91,6 +98,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptInt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -121,8 +133,8 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() // Try to register so that we can start receiving packets. - f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) + f.ID.RemoteAddress = addr.Addr + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) if err != nil { return err } @@ -163,7 +175,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, + false, /* reuse */ + 0, /* bindtoDevice */ ); err != nil { return err } @@ -184,9 +197,11 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - id: id, - stack: f.stack, - netProto: f.netProto, + stack: f.stack, + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, proto: f.proto, peerAddr: r.RemoteAddress, route: r.Clone(), @@ -251,7 +266,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } @@ -277,10 +292,17 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeTransFactory() stack.TransportProtocol { + return &fakeTransportProtocol{} +} + func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -340,9 +362,12 @@ func TestTransportReceive(t *testing.T) { } func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -408,9 +433,12 @@ func TestTransportControlReceive(t *testing.T) { } func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -452,7 +480,10 @@ func TestTransportSend(t *testing.T) { } func TestTransportOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) // Try an unsupported transport protocol. if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -493,20 +524,23 @@ func TestTransportOptions(t *testing.T) { } func TestTransportForwarding(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { + ep1 := loopback.New() + if err := s.CreateNIC(1, ep1); err != nil { t.Fatalf("CreateNIC #1 failed: %v", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress #1 failed: %v", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatalf("CreateNIC #2 failed: %v", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -545,7 +579,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.Inject(fakeNetNumber, req.ToVectorisedView()) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -559,7 +593,7 @@ func TestTransportForwarding(t *testing.T) { var p channel.PacketInfo select { - case p = <-linkEP2.C: + case p = <-ep2.C: default: t.Fatal("Response packet not forwarded") } @@ -571,9 +605,3 @@ func TestTransportForwarding(t *testing.T) { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } - -func init() { - stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { - return &fakeTransportProtocol{} - }) -} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 043dd549b..678a94616 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -57,6 +57,9 @@ type Error struct { // String implements fmt.Stringer.String. func (e *Error) String() string { + if e == nil { + return "<nil>" + } return e.msg } @@ -219,6 +222,15 @@ func (s *Subnet) Mask() AddressMask { return s.mask } +// Broadcast returns the subnet's broadcast address. +func (s *Subnet) Broadcast() Address { + addr := []byte(s.address) + for i := range addr { + addr[i] |= ^s.mask[i] + } + return Address(addr) +} + // NICID is a number that uniquely identifies a NIC. type NICID int32 @@ -252,31 +264,34 @@ type FullAddress struct { Port uint16 } -// Payload provides an interface around data that is being sent to an endpoint. -// This allows the endpoint to request the amount of data it needs based on -// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data. -type Payload interface { - // Get returns a slice containing exactly 'min(size, p.Size())' bytes. - Get(size int) ([]byte, *Error) - - // Size returns the payload size. - Size() int +// Payloader is an interface that provides data. +// +// This interface allows the endpoint to request the amount of data it needs +// based on internal buffers without exposing them. +type Payloader interface { + // FullPayload returns all available bytes. + FullPayload() ([]byte, *Error) + + // Payload returns a slice containing at most size bytes. + Payload(size int) ([]byte, *Error) } -// SlicePayload implements Payload on top of slices for convenience. +// SlicePayload implements Payloader for slices. +// +// This is typically used for tests. type SlicePayload []byte -// Get implements Payload. -func (s SlicePayload) Get(size int) ([]byte, *Error) { - if size > s.Size() { - size = s.Size() - } - return s[:size], nil +// FullPayload implements Payloader.FullPayload. +func (s SlicePayload) FullPayload() ([]byte, *Error) { + return s, nil } -// Size implements Payload. -func (s SlicePayload) Size() int { - return len(s) +// Payload implements Payloader.Payload. +func (s SlicePayload) Payload(size int) ([]byte, *Error) { + if size > len(s) { + size = len(s) + } + return s[:size], nil } // A ControlMessages contains socket control messages for IP sockets. @@ -329,7 +344,7 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // @@ -389,6 +404,10 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptInt sets a socket option, for simple cases where a value + // has the int type. + SetSockOptInt(opt SockOpt, v int) *Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error @@ -410,6 +429,26 @@ type Endpoint interface { // IPTables returns the iptables for this endpoint's stack. IPTables() (iptables.IPTables, error) + + // Info returns a copy to the transport endpoint info. + Info() EndpointInfo + + // Stats returns a reference to the endpoint stats. + Stats() EndpointStats +} + +// EndpointInfo is the interface implemented by each endpoint info struct. +type EndpointInfo interface { + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo + // marker interface. + IsEndpointInfo() +} + +// EndpointStats is the interface implemented by each endpoint stats struct. +type EndpointStats interface { + // IsEndpointStats is an empty method to implement the tcpip.EndpointStats + // marker interface. + IsEndpointStats() } // WriteOptions contains options for Endpoint.Write. @@ -423,16 +462,33 @@ type WriteOptions struct { // EndOfRecord has the same semantics as Linux's MSG_EOR. EndOfRecord bool + + // Atomic means that all data fetched from Payloader must be written to the + // endpoint. If Atomic is false, then data fetched from the Payloader may be + // discarded if available endpoint buffer space is unsufficient. + Atomic bool } // SockOpt represents socket options which values have the int type. type SockOpt int const ( - // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of - // unread bytes in the input buffer should be returned. + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption SockOpt = iota + // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the send buffer size option. + SendBufferSizeOption + + // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the receive buffer size option. + ReceiveBufferSizeOption + + // SendQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the output buffer should be returned. + SendQueueSizeOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -441,18 +497,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send -// buffer size option. -type SendBufferSizeOption int - -// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the -// receive buffer size option. -type ReceiveBufferSizeOption int - -// SendQueueSizeOption is used in GetSockOpt to specify that the number of -// unread bytes in the output buffer should be returned. -type SendQueueSizeOption int - // V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int @@ -474,6 +518,10 @@ type ReuseAddressOption int // to be bound to an identical socket address. type ReusePortOption int +// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets +// should bind only on a specific NIC. +type BindToDeviceOption string + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int @@ -525,6 +573,12 @@ type ModerateReceiveBufferOption bool // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. type MaxSegOption int +// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop +// limit value for unicast messages. The default is protocol specific. +// +// A zero value indicates the default. +type TTLOption uint8 + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -566,6 +620,18 @@ type OutOfBandInlineOption int // datagram sockets are allowed to send packets to a broadcast address. type BroadcastOption int +// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify +// a default TTL. +type DefaultTTLOption uint8 + +// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS +// for all subsequent outgoing IPv4 packets from the endpoint. +type IPv4TOSOption uint8 + +// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS +// for all subsequent outgoing IPv6 packets from the endpoint. +type IPv6TrafficClassOption uint8 + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. @@ -581,7 +647,7 @@ type Route struct { } // String implements the fmt.Stringer interface. -func (r *Route) String() string { +func (r Route) String() string { var out strings.Builder fmt.Fprintf(&out, "%s", r.Destination) if len(r.Gateway) > 0 { @@ -591,9 +657,6 @@ func (r *Route) String() string { return out.String() } -// LinkEndpointID represents a data link layer endpoint. -type LinkEndpointID uint64 - // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 @@ -720,6 +783,10 @@ type ICMPv4SentPacketStats struct { // Dropped is the total number of ICMPv4 packets dropped due to link // layer errors. Dropped *StatCounter + + // RateLimited is the total number of ICMPv6 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter } // ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats. @@ -738,6 +805,10 @@ type ICMPv6SentPacketStats struct { // Dropped is the total number of ICMPv6 packets dropped due to link // layer errors. Dropped *StatCounter + + // RateLimited is the total number of ICMPv6 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter } // ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats. @@ -790,6 +861,14 @@ type IPStats struct { // OutgoingPacketErrors is the total number of IP packets which failed // to write to a link-layer endpoint. OutgoingPacketErrors *StatCounter + + // MalformedPacketsReceived is the total number of IP Packets that were + // dropped due to the IP packet header failing validation checks. + MalformedPacketsReceived *StatCounter + + // MalformedFragmentsReceived is the total number of IP Fragments that were + // dropped due to the fragment failing validation checks. + MalformedFragmentsReceived *StatCounter } // TCPStats collects TCP-specific stats. @@ -836,6 +915,9 @@ type TCPStats struct { // SegmentsSent is the number of TCP segments sent. SegmentsSent *StatCounter + // SegmentSendErrors is the number of TCP segments failed to be sent. + SegmentSendErrors *StatCounter + // ResetsSent is the number of TCP resets sent. ResetsSent *StatCounter @@ -888,6 +970,9 @@ type UDPStats struct { // PacketsSent is the number of UDP datagrams sent via sendUDP. PacketsSent *StatCounter + + // PacketSendErrors is the number of datagrams failed to be sent. + PacketSendErrors *StatCounter } // Stats holds statistics about the networking stack. @@ -898,7 +983,7 @@ type Stats struct { // stack that were for an unknown or unsupported protocol. UnknownProtocolRcvdPackets *StatCounter - // MalformedRcvPackets is the number of packets received by the stack + // MalformedRcvdPackets is the number of packets received by the stack // that were deemed malformed. MalformedRcvdPackets *StatCounter @@ -918,18 +1003,95 @@ type Stats struct { UDP UDPStats } +// ReceiveErrors collects packet receive errors within transport endpoint. +type ReceiveErrors struct { + // ReceiveBufferOverflow is the number of received packets dropped + // due to the receive buffer being full. + ReceiveBufferOverflow StatCounter + + // MalformedPacketsReceived is the number of incoming packets + // dropped due to the packet header being in a malformed state. + MalformedPacketsReceived StatCounter + + // ClosedReceiver is the number of received packets dropped because + // of receiving endpoint state being closed. + ClosedReceiver StatCounter +} + +// SendErrors collects packet send errors within the transport layer for +// an endpoint. +type SendErrors struct { + // SendToNetworkFailed is the number of packets failed to be written to + // the network endpoint. + SendToNetworkFailed StatCounter + + // NoRoute is the number of times we failed to resolve IP route. + NoRoute StatCounter + + // NoLinkAddr is the number of times we failed to resolve ARP. + NoLinkAddr StatCounter +} + +// ReadErrors collects segment read errors from an endpoint read call. +type ReadErrors struct { + // ReadClosed is the number of received packet drops because the endpoint + // was shutdown for read. + ReadClosed StatCounter + + // InvalidEndpointState is the number of times we found the endpoint state + // to be unexpected. + InvalidEndpointState StatCounter +} + +// WriteErrors collects packet write errors from an endpoint write call. +type WriteErrors struct { + // WriteClosed is the number of packet drops because the endpoint + // was shutdown for write. + WriteClosed StatCounter + + // InvalidEndpointState is the number of times we found the endpoint state + // to be unexpected. + InvalidEndpointState StatCounter + + // InvalidArgs is the number of times invalid input arguments were + // provided for endpoint Write call. + InvalidArgs StatCounter +} + +// TransportEndpointStats collects statistics about the endpoint. +type TransportEndpointStats struct { + // PacketsReceived is the number of successful packet receives. + PacketsReceived StatCounter + + // PacketsSent is the number of successful packet sends. + PacketsSent StatCounter + + // ReceiveErrors collects packet receive errors within transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects packet read errors from an endpoint read call. + ReadErrors ReadErrors + + // SendErrors collects packet send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects packet write errors from an endpoint write call. + WriteErrors WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*TransportEndpointStats) IsEndpointStats() {} + func fillIn(v reflect.Value) { for i := 0; i < v.NumField(); i++ { v := v.Field(i) - switch v.Kind() { - case reflect.Ptr: - if s := v.Addr().Interface().(**StatCounter); *s == nil { - *s = &StatCounter{} + if s, ok := v.Addr().Interface().(**StatCounter); ok { + if *s == nil { + *s = new(StatCounter) } - case reflect.Struct: + } else { fillIn(v) - default: - panic(fmt.Sprintf("unexpected type %s", v.Type())) } } } @@ -940,6 +1102,26 @@ func (s Stats) FillIn() Stats { return s } +// Clone returns a copy of the TransportEndpointStats by atomically reading +// each field. +func (src *TransportEndpointStats) Clone() TransportEndpointStats { + var dst TransportEndpointStats + clone(reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src).Elem()) + return dst +} + +func clone(dst reflect.Value, src reflect.Value) { + for i := 0; i < dst.NumField(); i++ { + d := dst.Field(i) + s := src.Field(i) + if c, ok := s.Addr().Interface().(*StatCounter); ok { + d.Addr().Interface().(*StatCounter).IncrementBy(c.Value()) + } else { + clone(d, s) + } + } +} + // String implements the fmt.Stringer interface. func (a Address) String() string { switch len(a) { @@ -1065,6 +1247,47 @@ func (a AddressWithPrefix) String() string { return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) } +// Subnet converts the address and prefix into a Subnet value and returns it. +func (a AddressWithPrefix) Subnet() Subnet { + addrLen := len(a.Address) + if a.PrefixLen <= 0 { + return Subnet{ + address: Address(strings.Repeat("\x00", addrLen)), + mask: AddressMask(strings.Repeat("\x00", addrLen)), + } + } + if a.PrefixLen >= addrLen*8 { + return Subnet{ + address: a.Address, + mask: AddressMask(strings.Repeat("\xff", addrLen)), + } + } + + sa := make([]byte, addrLen) + sm := make([]byte, addrLen) + n := uint(a.PrefixLen) + for i := 0; i < addrLen; i++ { + if n >= 8 { + sa[i] = a.Address[i] + sm[i] = 0xff + n -= 8 + continue + } + sm[i] = ^byte(0xff >> n) + sa[i] = a.Address[i] & sm[i] + n = 0 + } + + // For extra caution, call NewSubnet rather than directly creating the Subnet + // value. If that fails it indicates a serious bug in this code, so panic is + // in order. + s, err := NewSubnet(Address(sa), AddressMask(sm)) + if err != nil { + panic("invalid subnet: " + err.Error()) + } + return s +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index fb3a0a5ee..8c0aacffa 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -195,3 +195,34 @@ func TestStatsString(t *testing.T) { t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) } } + +func TestAddressWithPrefixSubnet(t *testing.T) { + tests := []struct { + addr Address + prefixLen int + subnetAddr Address + subnetMask AddressMask + }{ + {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, + {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, + {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + } + for _, tt := range tests { + ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} + gotSubnet := ap.Subnet() + wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) + if err != nil { + t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + continue + } + if gotSubnet != wantSubnet { + t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) + } + } +} diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index d78a162b8..9254c3dea 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "icmp_packet_list", out = "icmp_packet_list.go", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 451d3880e..3187b336b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,7 +15,6 @@ package icmp import ( - "encoding/binary" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -53,11 +52,11 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber - transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue // The following fields are used to manage the receive queue, and are @@ -74,27 +73,23 @@ type endpoint struct { sndBufSize int // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags - id stack.TransportEndpointID state endpointState - // bindNICID and bindAddr are set via calls to Bind(). They are used to - // reject attempts to send data or connect via a different NIC or - // address - bindNICID tcpip.NICID - bindAddr tcpip.Address - // regNICID is the default NIC to be used when callers don't specify a - // NIC. - regNICID tcpip.NICID - route stack.Route `state:"manual"` -} - -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + route stack.Route `state:"manual"` + ttl uint8 + stats tcpip.TransportEndpointStats `state:"nosave"` +} + +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return &endpoint{ - stack: stack, - netProto: netProto, - transProto: transProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: stateInitial, }, nil } @@ -105,7 +100,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -144,6 +139,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -205,7 +201,30 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -256,12 +275,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + if e.BindNICID != 0 { + if nicid != 0 && nicid != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicid = e.BindNICID } toCopy := *to @@ -272,7 +291,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -290,17 +309,17 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } - switch e.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - err = send4(route, e.id.LocalPort, v) + err = send4(route, e.ID.LocalPort, v, e.ttl) case header.IPv6ProtocolNumber: - err = send6(route, e.id.LocalPort, v) + err = send6(route, e.ID.LocalPort, v, e.ttl) } if err != nil { @@ -315,8 +334,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(o) + e.mu.Unlock() + } + + return nil +} + +// SetSockOptInt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return nil } @@ -332,6 +363,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -342,40 +385,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() + case *tcpip.KeepaliveEnabledOption: + *o = 0 return nil - case *tcpip.ReceiveBufferSizeOption: + case *tcpip.TTLOption: e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + *o = tcpip.TTLOption(e.ttl) e.rcvMu.Unlock() return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: return tcpip.ErrUnknownProtocolOption } } -func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } - // Set the ident to the user-specified port. Sequence number should - // already be set by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident) - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(icmpv4, data) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + icmpv4.SetIdent(ident) data = data[header.ICMPv4MinimumSize:] // Linux performs these basic checks. @@ -386,22 +422,24 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) } -func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return tcpip.ErrInvalidEndpointState } - // Set the ident. Sequence number is provided by the user. - binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident) - - hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength())) + hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) copy(icmpv6, data) - data = data[header.ICMPv6EchoMinimumSize:] + // Set the ident. Sequence number is provided by the user. + icmpv6.SetIdent(ident) + data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return tcpip.ErrInvalidEndpointState @@ -410,18 +448,21 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv6.SetChecksum(0) icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { return 0, tcpip.ErrNoRoute } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -442,16 +483,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { localPort := uint16(0) switch e.state { case stateBound, stateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicid != 0 && nicid != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicid = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -462,7 +503,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -484,9 +525,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id + e.ID = id e.route = r.Clone() - e.regNICID = nicid + e.RegisterNICID = nicid e.state = stateConnected @@ -541,14 +582,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil @@ -595,8 +636,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id - e.regNICID = addr.NIC + e.ID = id + e.RegisterNICID = addr.NIC // Mark endpoint as bound. e.state = stateBound @@ -619,8 +660,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { return err } - e.bindNICID = addr.NIC - e.bindAddr = addr.Addr + e.BindNICID = addr.NIC + e.BindAddr = addr.Addr return nil } @@ -631,9 +672,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, nil } @@ -647,9 +688,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -675,17 +716,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { // Only accept echo replies. - switch e.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(vv.First()) if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: h := header.ICMPv6(vv.First()) if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } } @@ -693,9 +736,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.rcvMu.Lock() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } @@ -717,7 +768,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv pkt.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() - + e.stats.PacketsReceived.Increment() // Notify any waiters that there's data to be read now. if wasEmpty { e.waiterQueue.Notify(waiter.EventIn) @@ -733,3 +784,17 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C func (e *endpoint) State() uint32 { return 0 } + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index c587b96b6..9d263c0ec 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -76,19 +76,19 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) if err != nil { panic(err) } - e.id.LocalAddress = e.route.LocalAddress - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + e.ID.LocalAddress = e.route.LocalAddress + } else if len(e.ID.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } - e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 7fdba5d56..bfb16f7c3 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -14,16 +14,14 @@ // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport // protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and -// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or -// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when -// calling stack.New(). Then endpoints can be created by passing +// must be added to the project, and activated on the stack by passing +// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport +// protocols when calling stack.New(). Then endpoints can be created by passing // icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number // when calling Stack.NewEndpoint(). package icmp import ( - "encoding/binary" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -35,15 +33,9 @@ import ( ) const ( - // ProtocolName4 is the string representation of the icmp protocol name. - ProtocolName4 = "icmp4" - // ProtocolNumber4 is the ICMP protocol number. ProtocolNumber4 = header.ICMPv4ProtocolNumber - // ProtocolName6 is the string representation of the icmp protocol name. - ProtocolName6 = "icmp6" - // ProtocolNumber6 is the IPv6-ICMP protocol number. ProtocolNumber6 = header.ICMPv6ProtocolNumber ) @@ -92,7 +84,7 @@ func (p *protocol) MinimumPacketSize() int { case ProtocolNumber4: return header.ICMPv4MinimumSize case ProtocolNumber6: - return header.ICMPv6EchoMinimumSize + return header.ICMPv6MinimumSize } panic(fmt.Sprint("unknown protocol number: ", p.number)) } @@ -101,16 +93,18 @@ func (p *protocol) MinimumPacketSize() int { func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { switch p.number { case ProtocolNumber4: - return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil + hdr := header.ICMPv4(v) + return 0, hdr.Ident(), nil case ProtocolNumber6: - return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil + hdr := header.ICMPv6(v) + return 0, hdr.Ident(), nil } panic(fmt.Sprint("unknown protocol number: ", p.number)) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } @@ -124,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol { - return &protocol{ProtocolNumber4} - }) +// NewProtocol4 returns an ICMPv4 transport protocol. +func NewProtocol4() stack.TransportProtocol { + return &protocol{ProtocolNumber4} +} - stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol { - return &protocol{ProtocolNumber6} - }) +// NewProtocol6 returns an ICMPv6 transport protocol. +func NewProtocol6() stack.TransportProtocol { + return &protocol{ProtocolNumber6} } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 7241f6c19..fba598d51 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "packet_list", out = "packet_list.go", diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 13e17e2a6..b4c660859 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -62,11 +62,10 @@ type packet struct { // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber - transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue associated bool @@ -84,18 +83,10 @@ type endpoint struct { closed bool connected bool bound bool - // registeredNIC is the NIC to which th endpoint is explicitly - // registered. Is set when Connect or Bind are used to specify a NIC. - registeredNIC tcpip.NICID - // boundNIC and boundAddr are set on calls to Bind(). When callers - // attempt actions that would invalidate the binding data (e.g. sending - // data via a NIC other than boundNIC), the endpoint will return an - // error. - boundNIC tcpip.NICID - boundAddr tcpip.Address // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. - route stack.Route `state:"manual"` + route stack.Route `state:"manual"` + stats tcpip.TransportEndpointStats `state:"nosave"` } // NewEndpoint returns a raw endpoint for the given protocols. @@ -104,15 +95,17 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if netProto != header.IPv4ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } - ep := &endpoint{ - stack: stack, - netProto: netProto, - transProto: transProto, + e := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, @@ -123,81 +116,82 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - ep.rcvBufSizeMax = 0 - ep.waiterQueue = nil - return ep, nil + e.rcvBufSizeMax = 0 + e.waiterQueue = nil + return e, nil } - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { return nil, err } - return ep, nil + return e, nil } // Close implements tcpip.Endpoint.Close. -func (ep *endpoint) Close() { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() - if ep.closed || !ep.associated { + if e.closed || !e.associated { return } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) - ep.rcvMu.Lock() - defer ep.rcvMu.Unlock() + e.rcvMu.Lock() + defer e.rcvMu.Unlock() // Clear the receive list. - ep.rcvClosed = true - ep.rcvBufSize = 0 - for !ep.rcvList.Empty() { - ep.rcvList.Remove(ep.rcvList.Front()) + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + e.rcvList.Remove(e.rcvList.Front()) } - if ep.connected { - ep.route.Release() - ep.connected = false + if e.connected { + e.route.Release() + e.connected = false } - ep.closed = true + e.closed = true - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (iptables.IPTables, error) { - return ep.stack.IPTables(), nil +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil } // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !ep.associated { +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if !e.associated { return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue } - ep.rcvMu.Lock() + e.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. - if ep.rcvList.Empty() { + if e.rcvList.Empty() { err := tcpip.ErrWouldBlock - if ep.rcvClosed { + if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return buffer.View{}, tcpip.ControlMessages{}, err } - packet := ep.rcvList.Front() - ep.rcvList.Remove(packet) - ep.rcvBufSize -= packet.data.Size() + packet := e.rcvList.Front() + e.rcvList.Remove(packet) + e.rcvBufSize -= packet.data.Size() - ep.rcvMu.Unlock() + e.rcvMu.Unlock() if addr != nil { *addr = packet.senderAddr @@ -207,31 +201,54 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - ep.mu.RLock() + e.mu.RLock() - if ep.closed { - ep.mu.RUnlock() + if e.closed { + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } - payloadBytes, err := payload.Get(payload.Size()) + payloadBytes, err := p.FullPayload() if err != nil { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, err } // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !ep.associated { + if !e.associated { ip := header.IPv4(payloadBytes) - if !ip.IsValid(payload.Size()) { - ep.mu.RUnlock() + if !ip.IsValid(len(payloadBytes)) { + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -252,66 +269,66 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. - if !ep.connected { - ep.mu.RUnlock() + if !e.connected { + e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if ep.route.IsResolutionRequired() { - savedRoute := &ep.route + if e.route.IsResolutionRequired() { + savedRoute := &e.route // Promote lock to exclusive if using a shared route, // given that it may need to change in finishWrite. - ep.mu.RUnlock() - ep.mu.Lock() + e.mu.RUnlock() + e.mu.Lock() // Make sure that the route didn't change during the // time we didn't hold the lock. - if !ep.connected || savedRoute != &ep.route { - ep.mu.Unlock() + if !e.connected || savedRoute != &e.route { + e.mu.Unlock() return 0, nil, tcpip.ErrInvalidEndpointState } - n, ch, err := ep.finishWrite(payloadBytes, savedRoute) - ep.mu.Unlock() + n, ch, err := e.finishWrite(payloadBytes, savedRoute) + e.mu.Unlock() return n, ch, err } - n, ch, err := ep.finishWrite(payloadBytes, &ep.route) - ep.mu.RUnlock() + n, ch, err := e.finishWrite(payloadBytes, &e.route) + e.mu.RUnlock() return n, ch, err } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC - if ep.bound && nic != 0 && nic != ep.boundNIC { - ep.mu.RUnlock() + if e.bound && nic != 0 && nic != e.BindNICID { + e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } // We don't support IPv6 yet, so this has to be an IPv4 address. if len(opts.To.Addr) != header.IPv4AddressSize { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } - // Find the route to the destination. If boundAddress is 0, + // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. - route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false) + route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, err } - n, ch, err := ep.finishWrite(payloadBytes, &route) + n, ch, err := e.finishWrite(payloadBytes, &route) route.Release() - ep.mu.RUnlock() + e.mu.RUnlock() return n, ch, err } // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -324,16 +341,16 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch ep.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - if !ep.associated { + if !e.associated { if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { return 0, nil, err } break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil { + if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { return 0, nil, err } @@ -345,7 +362,7 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -355,11 +372,11 @@ func (*endpoint) Disconnect() *tcpip.Error { } // Connect implements tcpip.Endpoint.Connect. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() - if ep.closed { + if e.closed { return tcpip.ErrInvalidEndpointState } @@ -369,15 +386,15 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } nic := addr.NIC - if ep.bound { - if ep.boundNIC == 0 { + if e.bound { + if e.BindNICID == 0 { // If we're bound, but not to a specific NIC, the NIC // in addr will be used. Nothing to do here. } else if addr.NIC == 0 { // If we're bound to a specific NIC, but addr doesn't // specify a NIC, use the bound NIC. - nic = ep.boundNIC - } else if addr.NIC != ep.boundNIC { + nic = e.BindNICID + } else if addr.NIC != e.BindNICID { // We're bound and addr specifies a NIC. They must be // the same. return tcpip.ErrInvalidEndpointState @@ -385,53 +402,53 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the destination. - route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false) + route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) if err != nil { return err } defer route.Release() - if ep.associated { + if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { return err } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = nic + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = nic } // Save the route we've connected via. - ep.route = route.Clone() - ep.connected = true + e.route = route.Clone() + e.connected = true return nil } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() - if !ep.connected { + if !e.connected { return tcpip.ErrNotConnected } return nil } // Listen implements tcpip.Endpoint.Listen. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } // Bind implements tcpip.Endpoint.Bind. -func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() // Callers must provide an IPv4 address or no network address (for // binding to a NIC, but not an address). @@ -440,94 +457,100 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // If a local address was specified, verify that it's valid. - if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 { + if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } - if ep.associated { + if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { return err } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = addr.NIC - ep.boundNIC = addr.NIC + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = addr.NIC + e.BindNICID = addr.NIC } - ep.boundAddr = addr.Addr - ep.bound = true + e.BindAddr = addr.Addr + e.bound = true return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } // Readiness implements tcpip.Endpoint.Readiness. -func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // The endpoint is always writable. result := waiter.EventOut & mask // Determine whether the endpoint is readable. if (mask & waiter.EventIn) != 0 { - ep.rcvMu.Lock() - if !ep.rcvList.Empty() || ep.rcvClosed { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { result |= waiter.EventIn } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() } return result } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 - ep.rcvMu.Lock() - if !ep.rcvList.Empty() { - p := ep.rcvList.Front() + e.rcvMu.Lock() + if !e.rcvList.Empty() { + p := e.rcvList.Front() v = p.data.Size() } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - ep.mu.Lock() - *o = tcpip.SendBufferSizeOption(ep.sndBufSize) - ep.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) - ep.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -538,37 +561,45 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { - ep.rcvMu.Lock() +func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. - if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax { - ep.stack.Stats().DroppedPackets.Increment() - ep.rcvMu.Unlock() + if e.rcvClosed { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() return } - if ep.bound { + if e.rcvBufSize >= e.rcvBufSizeMax { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return + } + + if e.bound { // If bound to a NIC, only accept data for that NIC. - if ep.boundNIC != 0 && ep.boundNIC != route.NICID() { - ep.rcvMu.Unlock() + if e.BindNICID != 0 && e.BindNICID != route.NICID() { + e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress { - ep.rcvMu.Unlock() + if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + e.rcvMu.Unlock() return } } // If connected, only accept packets from the remote address we // connected to. - if ep.connected && ep.route.RemoteAddress != route.RemoteAddress { - ep.rcvMu.Unlock() + if e.connected && e.route.RemoteAddress != route.RemoteAddress { + e.rcvMu.Unlock() return } - wasEmpty := ep.rcvBufSize == 0 + wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. packet := &packet{ @@ -581,20 +612,34 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b combinedVV := netHeader.ToVectorisedView() combinedVV.Append(vv) packet.data = combinedVV.Clone(packet.views[:]) - packet.timestampNS = ep.stack.NowNanoseconds() - - ep.rcvList.PushBack(packet) - ep.rcvBufSize += packet.data.Size() + packet.timestampNS = e.stack.NowNanoseconds() - ep.rcvMu.Unlock() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() + e.rcvMu.Unlock() + e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. if wasEmpty { - ep.waiterQueue.Notify(waiter.EventIn) + e.waiterQueue.Notify(waiter.EventIn) } } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (e *endpoint) State() uint32 { return 0 } + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 168953dec..a6c7cc43a 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -73,7 +73,7 @@ func (ep *endpoint) Resume(s *stack.Stack) { // If the endpoint is connected, re-connect. if ep.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) if err != nil { panic(err) } @@ -81,12 +81,12 @@ func (ep *endpoint) Resume(s *stack.Stack) { // If the endpoint is bound, re-bind. if ep.bound { - if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index 783c21e6b..a2512d666 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -20,13 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -type factory struct{} +// EndpointFactory implements stack.UnassociatedEndpointFactory. +type EndpointFactory struct{} // NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. -func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } - -func init() { - stack.RegisterUnassociatedFactory(factory{}) -} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 1ee1a53f8..aed70e06f 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "tcp_segment_list", @@ -47,6 +48,7 @@ go_library( "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/seqnum", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e9c5099ea..844959fa0 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -143,6 +143,15 @@ func decSynRcvdCount() { synRcvdCount.Unlock() } +// synCookiesInUse() returns true if the synRcvdCount is greater than +// SynRcvdCountThreshold. +func synCookiesInUse() bool { + synRcvdCount.Lock() + v := synRcvdCount.value + synRcvdCount.Unlock() + return v >= SynRcvdCountThreshold +} + // newListenContext creates a new listen context. func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ @@ -220,7 +229,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n := newEndpoint(l.stack, netProto, nil) n.v6only = l.v6only - n.id = s.id + n.ID = s.id n.boundNICID = s.route.NICID() n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} @@ -233,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { n.Close() return nil, err } @@ -281,7 +290,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head h.resetToSynRcvd(cookie, irs, opts) if err := h.execute(); err != nil { - ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -302,14 +310,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.id] = n + l.pendingEndpoints[n.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.id) + delete(l.pendingEndpoints, n.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -354,6 +362,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) @@ -405,6 +414,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } decSynRcvdCount() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } else { @@ -412,6 +422,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // is full then drop the syn. if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } @@ -430,7 +441,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TSEcr: opts.TSVal, MSS: uint16(mss), } - sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } @@ -442,10 +453,32 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // complete the connection at the time of retransmit if // the backlog has space. e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } + if !synCookiesInUse() { + // Send a reset as this is an ACK for which there is no + // half open connections and we are not using cookies + // yet. + // + // The only time we should reach here when a connection + // was opened and closed really quickly and a delayed + // ACK was received from the sender. + replyWithReset(s) + return + } + + // Since SYN cookies are in use this is potentially an ACK to a + // SYN-ACK we sent but don't have a half open connection state + // as cookies are being used to protect against a potential SYN + // flood. In such cases validate the cookie and if valid create + // a fully connected endpoint and deliver to the accept queue. + // + // If not, silently drop the ACK to avoid leaking information + // when under a potential syn flood attack. + // // Validate the cookie. data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1) if !ok || int(data) >= len(mssTable) { @@ -475,6 +508,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } @@ -506,7 +540,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() v6only := e.v6only e.mu.Unlock() - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 00d2ae524..5ea036bea 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -238,6 +238,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() h.ep.state = StateSynRecv + ttl := h.ep.ttl h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), @@ -251,8 +252,10 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { SACKPermitted: rcvSynOpts.SACKPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) - + if ttl == 0 { + ttl = s.route.DefaultTTL() + } + h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -296,7 +299,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -383,6 +386,11 @@ func (h *handshake) resolveRoute() *tcpip.Error { switch index { case wakerForResolution: if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + if err == tcpip.ErrNoLinkAddress { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + } else if err != nil { + h.ep.stats.SendErrors.NoRoute.Increment() + } // Either success (err == nil) or failure. return err } @@ -460,7 +468,8 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + for h.state != handshakeCompleted { switch index, _ := s.Fetch(true); index { case wakerForResend: @@ -469,7 +478,7 @@ func (h *handshake) execute() *tcpip.Error { return tcpip.ErrTimeout } rt.Reset(timeOut) - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) case wakerForNotification: n := h.ep.fetchNotifications() @@ -579,16 +588,28 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { options := makeSynOptions(opts) - err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil) + // We ignore SYN send errors and let the callers re-attempt send. + if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil { + e.stats.SendErrors.SynSendToNetworkFailed.Increment() + } putOptions(options) - return err + return nil +} + +func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil { + e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() + return err + } + e.stats.SegmentsSent.Increment() + return nil } // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { optLen := len(opts) // Allocate a buffer for the TCP header. hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) @@ -624,12 +645,18 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } + if ttl == 0 { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + r.Stats().TCP.SegmentSendErrors.Increment() + return err + } r.Stats().TCP.SegmentsSent.Increment() if (flags & header.TCPFlagRst) != 0 { r.Stats().TCP.ResetsSent.Increment() } - - return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl) + return nil } // makeOptions makes an options slice. @@ -678,7 +705,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso) putOptions(options) return err } @@ -720,13 +747,18 @@ func (e *endpoint) handleClose() *tcpip.Error { return nil } -// resetConnectionLocked sends a RST segment and puts the endpoint in an error -// state with the given error code. This method must only be called from the -// protocol goroutine. +// resetConnectionLocked puts the endpoint in an error state with the given +// error code and sends a RST if and only if the error is not ErrConnectionReset +// indicating that the connection is being reset due to receiving a RST. This +// method must only be called from the protocol goroutine. func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + // Only send a reset if the connection is being aborted for a reason + // other than receiving a reset. e.state = StateError - e.hardError = err + e.HardError = err + if err != tcpip.ErrConnectionReset { + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + } } // completeWorkerLocked is called by the worker goroutine when it's about to @@ -806,7 +838,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() - return tcpip.ErrConnectionReset + return tcpip.ErrTimeout } // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with @@ -893,7 +925,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.mu.Lock() e.state = StateError - e.hardError = err + e.HardError = err // Lock released below. epilogue() @@ -1068,6 +1100,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.workMu.Lock() if err := funcs[v].f(); err != nil { e.mu.Lock() + // Ensure we release all endpoint registration and route + // references as the connection is now in an error + // state. + e.workerCleanup = true e.resetConnectionLocked(err) // Lock released below. epilogue() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index c54610a87..dfaa4a559 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -42,7 +42,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { } } -func testV4Connect(t *testing.T, c *context.Context) { +func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { // Start connection attempt. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventOut) @@ -55,12 +55,11 @@ func testV4Connect(t *testing.T, c *context.Context) { // Receive SYN packet. b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv4(t, b, synCheckers...) tcp := header.TCP(header.IPv4(b).Payload()) c.IRS = seqnum.Value(tcp.SequenceNumber()) @@ -76,14 +75,13 @@ func testV4Connect(t *testing.T, c *context.Context) { }) // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv4(t, c.GetPacket(), ackCheckers...) // Wait for connection to be established. select { @@ -152,7 +150,7 @@ func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { testV4Connect(t, c) } -func testV6Connect(t *testing.T, c *context.Context) { +func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { // Start connection attempt to IPv6 address. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventOut) @@ -165,12 +163,11 @@ func testV6Connect(t *testing.T, c *context.Context) { // Receive SYN packet. b := c.GetV6Packet() - checker.IPv6(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv6(t, b, synCheckers...) tcp := header.TCP(header.IPv6(b).Payload()) c.IRS = seqnum.Value(tcp.SequenceNumber()) @@ -186,14 +183,13 @@ func testV6Connect(t *testing.T, c *context.Context) { }) // Receive ACK packet. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv6(t, c.GetV6Packet(), ackCheckers...) // Wait for connection to be established. select { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ac927569a..a1b784b49 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "fmt" "math" "strings" @@ -26,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -170,6 +172,101 @@ type rcvBufAutoTuneParams struct { disabled bool } +// ReceiveErrors collect segment receive errors within transport layer. +type ReceiveErrors struct { + tcpip.ReceiveErrors + + // SegmentQueueDropped is the number of segments dropped due to + // a full segment queue. + SegmentQueueDropped tcpip.StatCounter + + // ChecksumErrors is the number of segments dropped due to bad checksums. + ChecksumErrors tcpip.StatCounter + + // ListenOverflowSynDrop is the number of times the listen queue overflowed + // and a SYN was dropped. + ListenOverflowSynDrop tcpip.StatCounter + + // ListenOverflowAckDrop is the number of times the final ACK + // in the handshake was dropped due to overflow. + ListenOverflowAckDrop tcpip.StatCounter + + // ZeroRcvWindowState is the number of times we advertised + // a zero receive window when rcvList is full. + ZeroRcvWindowState tcpip.StatCounter +} + +// SendErrors collect segment send errors within the transport layer. +type SendErrors struct { + tcpip.SendErrors + + // SegmentSendToNetworkFailed is the number of TCP segments failed to be sent + // to the network endpoint. + SegmentSendToNetworkFailed tcpip.StatCounter + + // SynSendToNetworkFailed is the number of TCP SYNs failed to be sent + // to the network endpoint. + SynSendToNetworkFailed tcpip.StatCounter + + // Retransmits is the number of TCP segments retransmitted. + Retransmits tcpip.StatCounter + + // FastRetransmit is the number of segments retransmitted in fast + // recovery. + FastRetransmit tcpip.StatCounter + + // Timeouts is the number of times the RTO expired. + Timeouts tcpip.StatCounter +} + +// Stats holds statistics about the endpoint. +type Stats struct { + // SegmentsReceived is the number of TCP segments received that + // the transport layer successfully parsed. + SegmentsReceived tcpip.StatCounter + + // SegmentsSent is the number of TCP segments sent. + SegmentsSent tcpip.StatCounter + + // FailedConnectionAttempts is the number of times we saw Connect and + // Accept errors. + FailedConnectionAttempts tcpip.StatCounter + + // ReceiveErrors collects segment receive errors within the + // transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects segment read errors from an endpoint read call. + ReadErrors tcpip.ReadErrors + + // SendErrors collects segment send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects segment write errors from an endpoint write call. + WriteErrors tcpip.WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*Stats) IsEndpointStats() {} + +// EndpointInfo holds useful information about a transport endpoint which +// can be queried by monitoring tools. +// +// +stateify savable +type EndpointInfo struct { + stack.TransportEndpointInfo + + // HardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. HardError is protected by endpoint mu. + HardError *tcpip.Error `state:".(string)"` +} + +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*EndpointInfo) IsEndpointInfo() {} + // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -178,6 +275,8 @@ type rcvBufAutoTuneParams struct { // // +stateify savable type endpoint struct { + EndpointInfo + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -186,8 +285,7 @@ type endpoint struct { // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. - stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber + stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` // lastError represents the last error that the endpoint reported; @@ -218,7 +316,6 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID state EndpointState `state:".(EndpointState)"` @@ -226,6 +323,7 @@ type endpoint struct { isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` + ttl uint8 v6only bool isConnectNotified bool // TCP should never broadcast but Linux nevertheless supports enabling/ @@ -240,11 +338,6 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` - // hardError is meaningful only when state is stateError, it stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. hardError is protected by mu. - hardError *tcpip.Error `state:".(string)"` - // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -280,6 +373,9 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. + bindToDevice tcpip.NICID + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -393,13 +489,19 @@ type endpoint struct { probe stack.TCPProbeFunc `state:"nosave"` // The following are only used to assist the restore run to re-connect. - bindAddress tcpip.Address connectingAddress tcpip.Address // amss is the advertised MSS to the peer by this endpoint. amss uint16 + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + gso *stack.GSO + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats Stats `state:"nosave"` } // StopWork halts packet processing. Only to be used in tests. @@ -427,10 +529,15 @@ type keepalive struct { waker sleep.Waker `state:"nosave"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + EndpointInfo: EndpointInfo{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + }, waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, @@ -446,26 +553,26 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite } var ss SendBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } var rs ReceiveBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } var cs tcpip.CongestionControlOption - if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } var mrb tcpip.ModerateReceiveBufferOption - if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - if p := stack.GetTCPProbe(); p != nil { + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -564,11 +671,11 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -625,12 +732,12 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -731,11 +838,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.hardError + he := e.HardError e.mu.RUnlock() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } + e.stats.ReadErrors.InvalidEndpointState.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -744,6 +852,9 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RUnlock() + if err == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } return v, tcpip.ControlMessages{}, err } @@ -787,7 +898,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { if !e.state.connected() { switch e.state { case StateError: - return 0, e.hardError + return 0, e.HardError default: return 0, tcpip.ErrClosedForSend } @@ -806,7 +917,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -818,50 +929,57 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha if err != nil { e.sndBufMu.Unlock() e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } - e.sndBufMu.Unlock() - e.mu.RUnlock() - - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + e.mu.RUnlock() } - // Copy in memory without holding sndBufMu so that worker goroutine can - // make progress independent of this operation. - v, perr := p.Get(avail) - if perr != nil { + // Fetch data. + v, perr := p.Payload(avail) + if perr != nil || len(v) == 0 { + if opts.Atomic { // See above. + e.sndBufMu.Unlock() + e.mu.RUnlock() + } + // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - e.mu.RLock() - e.sndBufMu.Lock() + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a - // write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - return 0, nil, err - } + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } - // Discard any excess data copied in due to avail being reduced due to a - // simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } // Add data to the send queue. - l := len(v) - s := newSegmentFromView(&e.route, e.id, v) - e.sndBufUsed += l - e.sndBufInQueue += seqnum.Size(l) + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() // Release the endpoint lock to prevent deadlocks due to lock // order inversion when acquiring workMu. @@ -875,7 +993,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return int64(l), nil, nil + + return int64(len(v)), nil, nil } // Peek reads data without consuming it from the endpoint. @@ -889,8 +1008,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.state; !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.hardError + return 0, tcpip.ControlMessages{}, e.HardError } + e.stats.ReadErrors.InvalidEndpointState.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -899,6 +1019,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.state.connected() { + e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -946,62 +1067,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil - - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.reuseAddr = v != 0 - e.mu.Unlock() - return nil - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.mu.Lock() - e.userMSS = int(userMSS) - e.mu.Unlock() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - +// SetSockOptInt sets a socket option. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1065,9 +1133,87 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sndBufMu.Unlock() return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch v := opt.(type) { + case tcpip.DelayOption: + if v == 0 { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.delay, 1) + } + return nil + + case tcpip.CorkOption: + if v == 0 { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicid, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicid + return nil + } + } + return tcpip.ErrUnknownDevice + + case tcpip.QuickAckOption: + if v == 0 { + atomic.StoreUint32(&e.slowAck, 1) + } else { + atomic.StoreUint32(&e.slowAck, 0) + } + return nil + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -1082,6 +1228,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.v6only = v != 0 return nil + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + return nil + case tcpip.KeepaliveEnabledOption: e.keepalive.Lock() e.keepalive.enabled = v != 0 @@ -1150,6 +1302,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. return tcpip.ErrNoSuchFile + + case tcpip.IPv4TOSOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + default: return nil } @@ -1176,6 +1345,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + v := e.sndBufSize + e.sndBufMu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + v := e.rcvBufSize + e.rcvListMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -1198,18 +1379,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.sndBufMu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize) - e.rcvListMu.Unlock() - return nil - case *tcpip.DelayOption: *o = 0 if v := atomic.LoadUint32(&e.delay); v != 0 { @@ -1246,6 +1415,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = "" + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1255,7 +1434,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -1269,6 +1448,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.mu.RLock() @@ -1333,13 +1518,25 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { // Fail if using a v4 mapped address on a v6only endpoint. if e.v6only { @@ -1355,7 +1552,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -1369,7 +1566,12 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return e.connect(addr, true, true) + err := e.connect(addr, true, true) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err } // connect connects the endpoint to its peer. In the normal non-S/R case, the @@ -1378,14 +1580,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() connectingAddr := addr.Addr @@ -1430,29 +1627,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return tcpip.ErrAlreadyConnecting case StateError: - return e.hardError + return e.HardError default: return tcpip.ErrInvalidEndpointState } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() - origID := e.id + origID := e.ID netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.id.LocalAddress = r.LocalAddress - e.id.RemoteAddress = r.RemoteAddress - e.id.RemotePort = addr.Port + e.ID.LocalAddress = r.LocalAddress + e.ID.RemoteAddress = r.RemoteAddress + e.ID.RemotePort = addr.Port - if e.id.LocalPort != 0 { + if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1461,20 +1658,35 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.id.LocalAddress == e.id.RemoteAddress - if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - if sameAddr && p == e.id.RemotePort { + sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + + // Calculate a port offset based on the destination IP/port and + // src IP to ensure that for a given tuple (srcIP, destIP, + // destPort) the offset used as a starting point is the same to + // ensure that we can cycle through the port space effectively. + h := jenkins.Sum32(e.stack.PortSeed()) + h.Write([]byte(e.ID.LocalAddress)) + h.Write([]byte(e.ID.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) + h.Write(portBuf) + portOffset := h.Sum32() + + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if sameAddr && p == e.ID.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { + // reusePort is false below because connect cannot reuse a port even if + // reusePort was set. + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { return false, nil } - id := e.id + id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: - e.id = id + e.ID = id return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1490,7 +1702,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -1509,7 +1721,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.segmentQueue.mu.Lock() for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.id + s.id = e.ID s.route = r.Clone() e.sndWaker.Assert() } @@ -1569,7 +1781,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.id, nil) + s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ @@ -1597,14 +1809,18 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { +func (e *endpoint) Listen(backlog int) *tcpip.Error { + err := e.listen(backlog) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err +} + +func (e *endpoint) listen(backlog int) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -1630,11 +1846,12 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { + e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { return err } @@ -1698,7 +1915,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { return tcpip.ErrAlreadyBound } - e.bindAddress = addr.Addr + e.BindAddr = addr.Addr netProto, err := e.checkV4Mapped(&addr) if err != nil { return err @@ -1715,26 +1932,26 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) if err != nil { return err } e.isPortReserved = true e.effectiveNetProtos = netProtos - e.id.LocalPort = port + e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func() { + defer func(bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil - e.id.LocalPort = 0 - e.id.LocalAddress = "" + e.ID.LocalPort = 0 + e.ID.LocalAddress = "" e.boundNICID = 0 } - }() + }(e.bindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -1745,7 +1962,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } e.boundNICID = nic - e.id.LocalAddress = addr.Addr + e.ID.LocalAddress = addr.Addr } // Mark endpoint as bound. @@ -1760,8 +1977,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -1776,8 +1993,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, NIC: e.boundNICID, }, nil } @@ -1789,6 +2006,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() return } @@ -1796,11 +2014,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if !s.csumValid { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() return } e.stack.Stats().TCP.ValidSegmentsReceived.Increment() + e.stats.SegmentsReceived.Increment() if (s.flags & header.TCPFlagRst) != 0 { e.stack.Stats().TCP.ResetsReceived.Increment() } @@ -1811,6 +2031,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } else { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.SegmentQueueDropped.Increment() s.decRef() } } @@ -1860,6 +2081,7 @@ func (e *endpoint) readyToRead(s *segment) { // that a subsequent read of the segment will correctly trigger // a non-zero notification. if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() e.zeroWindow = true } e.rcvList.PushBack(s) @@ -2012,7 +2234,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Copy EndpointID. e.mu.Lock() - s.ID = stack.TCPEndpointID(e.id) + s.ID = stack.TCPEndpointID(e.ID) e.mu.Unlock() // Copy endpoint rcv state. @@ -2119,7 +2341,7 @@ func (e *endpoint) initGSO() { gso.Type = stack.GSOTCPv6 gso.L3HdrLen = header.IPv6MinimumSize default: - panic(fmt.Sprintf("Unknown netProto: %v", e.netProto)) + panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) } gso.NeedsCsum = true gso.CsumOffset = header.TCPChecksumOffset @@ -2135,6 +2357,20 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.EndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + func mssForRoute(r *stack.Route) uint16 { return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 831389ec7..eae17237e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,7 +55,7 @@ func (e *endpoint) beforeSave() { case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() @@ -190,10 +190,10 @@ func (e *endpoint) Resume(s *stack.Stack) { bind := func() { e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress + if len(e.BindAddr) == 0 { + e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { + if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil { panic("endpoint binding failed: " + err.String()) } } @@ -202,19 +202,19 @@ func (e *endpoint) Resume(s *stack.Stack) { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.id.RemoteAddress + e.connectingAddress = e.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } connectedLoading.Done() @@ -236,7 +236,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -288,21 +288,21 @@ func (e *endpoint) loadLastError(s string) { } // saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { +func (e *EndpointInfo) saveHardError() string { + if e.HardError == nil { return "" } - return e.hardError.String() + return e.HardError.String() } // loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { +func (e *EndpointInfo) loadHardError(s string) { if s == "" { return } - e.hardError = loadError(s) + e.HardError = loadError(s) } var messageToError map[string]*tcpip.Error diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index ee04dcfcc..db40785d3 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -14,7 +14,7 @@ // Package tcp contains the implementation of the TCP transport protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the +// activated on the stack by passing tcp.NewProtocol() as one of the // transport protocols when calling stack.New(). Then endpoints can be created // by passing tcp.ProtocolNumber as the transport protocol number when calling // Stack.NewEndpoint(). @@ -34,9 +34,6 @@ import ( ) const ( - // ProtocolName is the string representation of the tcp protocol name. - ProtocolName = "tcp" - // ProtocolNumber is the tcp protocol number. ProtocolNumber = header.TCPProtocolNumber @@ -129,7 +126,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() @@ -156,7 +153,7 @@ func replyWithReset(s *segment) { ack := s.sequenceNumber.Add(s.logicalLen()) - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. @@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{ - sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, - congestionControl: ccReno, - availableCongestionControl: []string{ccReno, ccCubic}, - } - }) +// NewProtocol returns a TCP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{ + sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, + congestionControl: ccReno, + availableCongestionControl: []string{ccReno, ccCubic}, + } } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 1f9b1e0ef..8332a0179 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -417,6 +417,7 @@ func (s *sender) resendSegment() { s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() + s.ep.stats.SendErrors.FastRetransmit.Increment() // Run SetPipe() as per RFC 6675 section 5 Step 4.4 s.SetPipe() @@ -435,6 +436,7 @@ func (s *sender) retransmitTimerExpired() bool { } s.ep.stack.Stats().TCP.Timeouts.Increment() + s.ep.stats.SendErrors.Timeouts.Increment() // Give up if we've waited more than a minute since the last resend. if s.rto >= 60*time.Second { @@ -664,7 +666,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se segEnd = seg.sequenceNumber.Add(1) // Transition to FIN-WAIT1 state since we're initiating an active close. s.ep.mu.Lock() - s.ep.state = StateFinWait1 + switch s.ep.state { + case StateCloseWait: + // We've already received a FIN and are now sending our own. The + // sender is now awaiting a final ACK for this FIN. + s.ep.state = StateLastAck + default: + s.ep.state = StateFinWait1 + } s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. @@ -1181,6 +1190,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { func (s *sender) sendSegment(seg *segment) *tcpip.Error { if !seg.xmitTime.IsZero() { s.ep.stack.Stats().TCP.Retransmits.Increment() + s.ep.stats.SendErrors.Retransmits.Increment() if s.sndCwnd < s.sndSsthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 272bbcdbd..782d7b42c 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { enableCUBIC(t, c) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -500,6 +500,14 @@ func TestRetransmit(t *testing.T) { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 4e7f1a740..afea124ec 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -520,10 +520,18 @@ func TestSACKRecovery(t *testing.T) { t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { + t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + } + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) // Acknowledge all pending data to recover point. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index f79b8ec5f..6d022a266 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -97,9 +97,12 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) } } @@ -122,6 +125,9 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) + } } func TestTCPSegmentsSentIncrement(t *testing.T) { @@ -131,11 +137,14 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { stats := c.Stack().Stats() // SYN and ACK want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { + t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) + } } func TestTCPResetsSentIncrement(t *testing.T) { @@ -190,21 +199,122 @@ func TestTCPResetsSentIncrement(t *testing.T) { } } +// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates +// a RST if an ACK is received on the listening socket for which there is no +// active handshake in progress and we are not using SYN cookies. +func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + c.GetPacket() + + // Now resend the same ACK, this ACK should generate a RST as there + // should be no endpoint in SYN-RCVD state and we are not using + // syn-cookies yet. The reason we send the same ACK is we need a valid + // cookie(IRS) generated by the netstack without which the ACK will be + // rejected. + c.SendPacket(nil, ackHeaders) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) +} + func TestTCPResetsReceivedIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() stats := c.Stack().Stats() want := stats.TCP.ResetsReceived.Value() + 1 - ackNum := seqnum.Value(789) + iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(ackNum, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, - SeqNum: c.IRS.Add(2), - AckNum: ackNum.Add(2), + SeqNum: iss.Add(1), + AckNum: c.IRS.Add(1), RcvWnd: rcvWnd, Flags: header.TCPFlagRst, }) @@ -214,18 +324,43 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { } } +func TestTCPResetsDoNotGenerateResets(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.ResetsReceived.Value() + 1 + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(1), + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + Flags: header.TCPFlagRst, + }) + + if got := stats.TCP.ResetsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) + } + c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) +} + func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) } func TestNonBlockingClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -241,7 +376,7 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -291,7 +426,7 @@ func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -339,11 +474,172 @@ func TestSimpleReceive(t *testing.T) { ) } +func TestTOSV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep + + const tos = 0xC0 + if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { + t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + } + + var v tcpip.IPv4TOSOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv4TOSOption(tos); v != want { + t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testV4Connect(t, c, checker.TOS(tos, 0)) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestTrafficClassV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + const tos = 0xC0 + if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { + t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + } + + var v tcpip.IPv6TrafficClassOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv6TrafficClassOption(tos); v != want { + t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + // Test the connection request. + testV6Connect(t, c, checker.TOS(tos, 0)) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetV6Packet() + checker.IPv6(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestConnectBindToDevice(t *testing.T) { + for _, test := range []struct { + name string + device string + want tcp.EndpointState + }{ + {"RightDevice", "nic1", tcp.StateEstablished}, + {"WrongDevice", "nic2", tcp.StateSynSent}, + {"AnyDevice", "", tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + bindToDevice := tcpip.BindToDeviceOption(test.device) + c.EP.SetSockOpt(bindToDevice) + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + c.GetPacket() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + }) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -431,8 +727,7 @@ func TestOutOfOrderFlood(t *testing.T) { defer c.Cleanup() // Create a new connection with initial window size of 10. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -505,7 +800,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -574,7 +869,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -659,7 +954,7 @@ func TestShutdownRead(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -672,14 +967,17 @@ func TestShutdownRead(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { + t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) + } } func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -729,6 +1027,11 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("got data = %v, want = %v", v, data) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { + t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) + } + // Check that we get an ACK for the newly non-zero window. checker.IPv4(t, c.GetPacket(), checker.TCP( @@ -746,11 +1049,9 @@ func TestNoWindowShrinking(t *testing.T) { defer c.Cleanup() // Start off with a window size of 10, then shrink it to 5. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) - opt = 5 - if err := c.EP.SetSockOpt(opt); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { t.Fatalf("SetSockOpt failed: %v", err) } @@ -850,7 +1151,7 @@ func TestSimpleSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -891,7 +1192,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 0, nil) + c.CreateConnected(789, 0, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -949,8 +1250,7 @@ func TestScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -984,8 +1284,7 @@ func TestNonScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1025,7 +1324,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1098,7 +1397,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1167,8 +1466,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Set the window size such that a window scale of 4 will be used. const wnd = 65535 * 10 const ws = uint32(4) - opt := tcpip.ReceiveBufferSizeOption(wnd) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1273,7 +1571,7 @@ func TestSegmentMerging(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Prevent the endpoint from processing packets. test.stop(c.EP) @@ -1323,7 +1621,7 @@ func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1371,7 +1669,7 @@ func TestUndelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1453,7 +1751,7 @@ func TestMSSNotDelayed(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -1569,16 +1867,44 @@ func TestSendGreaterThanMTU(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) testBrokenUpWrite(t, c, maxPayload) } +func TestSetTTL(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := context.New(t, 65535) + defer c.Cleanup() + + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + t.Fatalf("SetSockOpt failed: %v", err) + } + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + + checker.IPv4(t, b, checker.TTL(wantTTL)) + }) + } +} + func TestActiveSendMSSLessThanMTU(t *testing.T) { const maxPayload = 100 c := context.New(t, 65535) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) testBrokenUpWrite(t, c, maxPayload) @@ -1601,7 +1927,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1745,7 +2071,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 2 - if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1847,7 +2173,7 @@ func TestReceiveOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -1878,13 +2204,20 @@ loop: t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) } } + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { + t.Fatalf("got EP state is not StateError") + } } func TestSendOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -1909,7 +2242,7 @@ func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -1952,7 +2285,7 @@ func TestFinRetransmit(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2006,7 +2339,7 @@ func TestFinWithNoPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. view := buffer.NewView(10) @@ -2077,7 +2410,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write enough segments to fill the congestion window before ACK'ing // any of them. @@ -2165,7 +2498,7 @@ func TestFinWithPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) @@ -2251,7 +2584,7 @@ func TestFinWithPartialAck(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. @@ -2383,7 +2716,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { defer c.Cleanup() maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), header.TCPOptionWS, 3, scale, header.TCPOptionNOP, }) @@ -2433,7 +2766,7 @@ func TestScaledSendWindow(t *testing.T) { func TestReceivedValidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ValidSegmentsReceived.Value() + 1 @@ -2449,12 +2782,23 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { + t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) + } + // Ensure there were no errors during handshake. If these stats have + // incremented, then the connection should not have been established. + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) + } + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) + } } func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.InvalidSegmentsReceived.Value() + 1 vv := c.BuildSegment(nil, &context.Headers{ @@ -2473,12 +2817,15 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } } func TestReceivedIncorrectChecksumIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ChecksumErrors.Value() + 1 vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ @@ -2499,6 +2846,9 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { if got := stats.TCP.ChecksumErrors.Value(); got != want { t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) + } } func TestReceivedSegmentQueuing(t *testing.T) { @@ -2509,7 +2859,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send 200 segments. data := []byte{1, 2, 3} @@ -2555,7 +2905,7 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -2730,8 +3080,8 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2743,8 +3093,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2754,7 +3104,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { } func TestDefaultBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) @@ -2800,7 +3153,10 @@ func TestDefaultBufferSizes(t *testing.T) { } func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) @@ -2819,37 +3175,96 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) } +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer ep.Close() + + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } + + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) + } + + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s + } + + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) + } +} + func makeStack() (*stack.Stack, *tcpip.Error) { - s := stack.New([]string{ - ipv4.ProtocolName, - ipv6.ProtocolName, - }, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + ipv6.NewProtocol(), + }, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) id := loopback.New() if testing.Verbose() { @@ -3105,7 +3520,7 @@ func TestPathMTUDiscovery(t *testing.T) { // Create new connection with MSS of 1460. const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -3182,7 +3597,7 @@ func TestTCPEndpointProbe(t *testing.T) { invoked <- struct{}{} }) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -3356,7 +3771,7 @@ func TestKeepalive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) @@ -3459,8 +3874,8 @@ func TestKeepalive(t *testing.T) { ), ) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) } } @@ -3886,6 +4301,9 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) + } we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -3924,6 +4342,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + // Expect InvalidEndpointState errors on a read at this point. + if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + } + if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { + t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + } + if err := ep.Listen(10); err != nil { t.Fatalf("Listen failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 272481aa0..ef823e4ae 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -137,7 +137,10 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Allow minimum send/receive buffer sizes to be 1 during tests. if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { @@ -150,11 +153,19 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - id, linkEP := channel.New(1000, mtu, "") + ep := channel.New(1000, mtu, "") + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { + wep = sniffer.New(ep) + } + if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) if testing.Verbose() { - id = sniffer.New(id) + wep2 = sniffer.New(channel.New(1000, mtu, "")) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -180,7 +191,7 @@ func New(t *testing.T, mtu uint32) *Context { return &Context{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), } } @@ -267,7 +278,7 @@ func (c *Context) GetPacketNonBlocking() []byte { // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2)) + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) if len(buf) > maxTotalSize { buf = buf[:maxTotalSize] } @@ -286,9 +297,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) icmp.SetType(typ) icmp.SetCode(code) - - copy(icmp[header.ICMPv4PayloadOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2) + const icmpv4VariableHeaderOffset = 4 + copy(icmp[icmpv4VariableHeaderOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) @@ -511,7 +522,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { } // CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { +func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } @@ -584,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.Port = tcpHdr.SourcePort() } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { +// Create creates a TCP endpoint. +func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -597,11 +604,20 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("NewEndpoint failed: %v", err) } - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + if epRcvBuf != -1 { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { + c.Create(epRcvBuf) c.Connect(iss, rcvWnd, options) } diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 4bec48c0f..43fcc27f0 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index ac2666f69..c9460aa0d 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "udp_packet_list", @@ -50,6 +51,7 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index ac5905772..6e87245b7 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "math" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -37,15 +36,35 @@ type udpPacket struct { views [8]buffer.View `state:"nosave"` } -type endpointState int +// EndpointState represents the state of a UDP endpoint. +type EndpointState uint32 +// Endpoint states. Note that are represented in a netstack-specific manner and +// may not be meaningful externally. Specifically, they need to be translated to +// Linux's representation for these states if presented to userspace. const ( - stateInitial endpointState = iota - stateBound - stateConnected - stateClosed + StateInitial EndpointState = iota + StateBound + StateConnected + StateClosed ) +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnected: + return "CONNECTING" + case StateClosed: + return "CLOSED" + default: + return "UNKNOWN" + } +} + // endpoint represents a UDP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -55,10 +74,11 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue // The following fields are used to manage the receive queue, and are @@ -73,20 +93,23 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int - id stack.TransportEndpointID - state endpointState - bindNICID tcpip.NICID - regNICID tcpip.NICID + state EndpointState route stack.Route `state:"manual"` dstPort uint16 v6only bool + ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID multicastLoop bool reusePort bool + bindToDevice tcpip.NICID broadcast bool + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -101,6 +124,9 @@ type endpoint struct { // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). effectiveNetProtos []tcpip.NetworkProtocolNumber + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats tcpip.TransportEndpointStats `state:"nosave"` } // +stateify savable @@ -109,10 +135,13 @@ type multicastMembership struct { multicastAddr tcpip.Address } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { return &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, waiterQueue: waiterQueue, // RFC 1075 section 5.4 recommends a TTL of 1 for membership // requests. @@ -130,6 +159,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: StateInitial, } } @@ -140,13 +170,13 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { - case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + case StateBound, StateConnected: + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) } for _, mem := range e.multicastMemberships { - e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } e.multicastMemberships = nil @@ -163,7 +193,7 @@ func (e *endpoint) Close() { e.route.Release() // Update the state. - e.state = stateClosed + e.state = StateClosed e.mu.Unlock() @@ -186,6 +216,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -211,11 +242,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // Returns true for retry if preparation should be retried. func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { switch e.state { - case stateInitial: - case stateConnected: + case StateInitial: + case StateConnected: return false, nil - case stateBound: + case StateBound: if to == nil { return false, tcpip.ErrDestinationRequired } @@ -232,7 +263,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != stateInitial { + if e.state != StateInitial { return true, nil } @@ -248,7 +279,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // configured multicast interface if no interface is specified and the // specified address is a multicast address. func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { - localAddr := e.id.LocalAddress + localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" @@ -273,17 +304,35 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - if p.Size() > math.MaxUint16 { - // Payload can't possibly fit in a packet. - return 0, nil, tcpip.ErrMessageTooLong - } - to := opts.To e.mu.RLock() @@ -322,7 +371,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha defer e.mu.Unlock() // Recheck state after lock was re-acquired. - if e.state != stateConnected { + if e.state != StateConnected { return 0, nil, tcpip.ErrInvalidEndpointState } } @@ -330,12 +379,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + if e.BindNICID != 0 { + if nicid != 0 && nicid != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicid = e.BindNICID } if to.Addr == header.IPv4Broadcast && !e.broadcast { @@ -366,17 +415,25 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } + + ttl := e.ttl + useDefaultTTL := ttl == 0 - ttl := route.DefaultTTL() if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { ttl = e.multicastTTL + // Multicast allows a 0 TTL. + useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -387,12 +444,17 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -400,12 +462,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrInvalidEndpointState } e.v6only = v != 0 + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -440,7 +507,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if e.bindNICID != 0 && e.bindNICID != nic { + if e.BindNICID != 0 && e.BindNICID != nic { return tcpip.ErrInvalidEndpointState } @@ -467,7 +534,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -484,7 +551,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -505,7 +572,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -527,7 +594,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrBadLocalAddress } - if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -544,12 +611,39 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.reusePort = v != 0 e.mu.Unlock() + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicid, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicid + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 e.mu.Unlock() return nil + + case tcpip.IPv4TOSOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + return nil } return nil } @@ -566,7 +660,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil } + return -1, tcpip.ErrUnknownProtocolOption } @@ -576,21 +683,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -604,6 +699,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.MulticastTTLOption: e.mu.Lock() *o = tcpip.MulticastTTLOption(e.multicastTTL) @@ -638,6 +739,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = tcpip.BindToDeviceOption("") + return nil + case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -653,6 +764,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -660,7 +783,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -683,14 +806,21 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u udp.SetChecksum(^udp.CalculateChecksum(xsum)) } + if useDefaultTTL { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + r.Stats().UDP.PacketSendErrors.Increment() + return err + } + // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - - return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl) + return nil } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if len(addr.Addr) == 0 { return netProto, nil } @@ -707,14 +837,14 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t } // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.id.LocalAddress) == 16 { + if !allowMismatch && len(e.ID.LocalAddress) == 16 { return 0, tcpip.ErrNetworkUnreachable } } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -726,28 +856,32 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.state != stateConnected { + if e.state != StateConnected { return nil } id := stack.TransportEndpointID{} // Exclude ephemerally bound endpoints. - if e.bindNICID != 0 || e.id.LocalAddress == "" { + if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error id = stack.TransportEndpointID{ - LocalPort: e.id.LocalPort, - LocalAddress: e.id.LocalAddress, + LocalPort: e.ID.LocalPort, + LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } - e.state = stateBound + e.state = StateBound } else { - e.state = stateInitial + if e.ID.LocalPort != 0 { + // Release the ephemeral port. + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + } + e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.id = id + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.ID = id e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -772,18 +906,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicid := addr.NIC var localPort uint16 switch e.state { - case stateInitial: - case stateBound, stateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + case StateInitial: + case StateBound, StateConnected: + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicid != 0 && nicid != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicid = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -795,13 +929,13 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() id := stack.TransportEndpointID{ - LocalAddress: e.id.LocalAddress, + LocalAddress: e.ID.LocalAddress, LocalPort: localPort, RemotePort: addr.Port, RemoteAddress: r.RemoteAddress, } - if e.state == stateInitial { + if e.state == StateInitial { id.LocalAddress = r.LocalAddress } @@ -822,17 +956,17 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Remove the old registration. - if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + if e.ID.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) } - e.id = id + e.ID = id e.route = r.Clone() e.dstPort = addr.Port - e.regNICID = nicid + e.RegisterNICID = nicid e.effectiveNetProtos = netProtos - e.state = stateConnected + e.state = StateConnected e.rcvMu.Lock() e.rcvReady = true @@ -854,7 +988,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. - if e.state != stateBound && e.state != stateConnected { + if e.state != StateBound && e.state != StateConnected { return tcpip.ErrNotConnected } @@ -885,17 +1019,17 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if e.id.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) + if e.ID.LocalPort == 0 { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { return id, err } id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) } return id, err } @@ -903,7 +1037,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -941,12 +1075,12 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id - e.regNICID = nicid + e.ID = id + e.RegisterNICID = nicid e.effectiveNetProtos = netProtos // Mark endpoint as bound. - e.state = stateBound + e.state = StateBound e.rcvMu.Lock() e.rcvReady = true @@ -967,7 +1101,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // Save the effective NICID generated by bindLocked. - e.bindNICID = e.regNICID + e.BindNICID = e.RegisterNICID return nil } @@ -978,9 +1112,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, nil } @@ -989,14 +1123,14 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { + if e.state != StateConnected { return tcpip.FullAddress{}, tcpip.ErrNotConnected } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -1026,6 +1160,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if int(hdr.Length()) > vv.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } @@ -1033,11 +1168,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() + e.stats.PacketsReceived.Increment() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } @@ -1069,10 +1213,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } -// State implements socket.Socket.State. +// State implements tcpip.Endpoint.State. func (e *endpoint) State() uint32 { - // TODO(b/112063468): Translate internal state to values returned by Linux. - return 0 + e.mu.Lock() + defer e.mu.Unlock() + return uint32(e.state) +} + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats } func isBroadcastOrMulticast(a tcpip.Address) bool { diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 5cbb56120..b227e353b 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -72,12 +72,12 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } } - if e.state != stateBound && e.state != stateConnected { + if e.state != StateBound && e.state != StateConnected { return } @@ -92,14 +92,14 @@ func (e *endpoint) Resume(s *stack.Stack) { } var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + if e.state == StateConnected { + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) if err != nil { panic(err) } - } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } @@ -107,9 +107,9 @@ func (e *endpoint) Resume(s *stack.Stack) { // Our saved state had a port, but we don't actually have a // reservation. We need to remove the port from our state, but still // pass it to the reservation machinery. - id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id := e.ID + e.ID.LocalPort = 0 + e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a874fc9fd..d399ec722 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -74,17 +74,17 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { ep.Close() return nil, err } - ep.id = r.id + ep.ID = r.id ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort - ep.regNICID = r.route.NICID() + ep.RegisterNICID = r.route.NICID() - ep.state = stateConnected + ep.state = StateConnected ep.rcvMu.Lock() ep.rcvReady = true diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f76e7fbe1..de026880f 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -14,7 +14,7 @@ // Package udp contains the implementation of the UDP transport protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.ProtocolName (or "udp") as one of the +// activated on the stack by passing udp.NewProtocol() as one of the // transport protocols when calling stack.New(). Then endpoints can be created // by passing udp.ProtocolNumber as the transport protocol number when calling // Stack.NewEndpoint(). @@ -30,9 +30,6 @@ import ( ) const ( - // ProtocolName is the string representation of the udp protocol name. - ProtocolName = "udp" - // ProtocolNumber is the udp protocol number. ProtocolNumber = header.UDPProtocolNumber ) @@ -69,7 +66,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { + // Get the header then trim it from the view. + hdr := header.UDP(vv.First()) + if int(hdr.Length()) > vv.Size() { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + // TODO(b/129426613): only send an ICMP message if UDP checksum is valid. + + // Only send ICMP error if the address is not a multicast/broadcast + // v4/v6 address or the source is not the unspecified address. + // + // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4 + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any { + return true + } + + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported; or + // + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + switch len(id.LocalAddress) { + case header.IPv4AddressSize: + if !r.Stack().AllowICMPMessage() { + r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment() + return true + } + // As per RFC 1812 Section 4.3.2.3 + // + // ICMP datagram SHOULD contain as much of the original + // datagram as possible without the length of the ICMP + // datagram exceeding 576 bytes + // + // NOTE: The above RFC referenced is different from the original + // recommendation in RFC 1122 where it mentioned that at least 8 + // bytes of the payload must be included. Today linux and other + // systems implement the] RFC1812 definition and not the original + // RFC 1122 requirement. + mtu := int(r.MTU()) + if mtu > header.IPv4MinimumProcessableDatagramSize { + mtu = header.IPv4MinimumProcessableDatagramSize + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + available := int(mtu) - headerLen + payloadLen := len(netHeader) + vv.Size() + if payloadLen > available { + payloadLen = available + } + + payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) + payload.Append(vv) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4DstUnreachable) + pkt.SetCode(header.ICMPv4PortUnreachable) + pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) + r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + + case header.IPv6AddressSize: + if !r.Stack().AllowICMPMessage() { + r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment() + return true + } + + // As per RFC 4443 section 2.4 + // + // (c) Every ICMPv6 error message (type < 128) MUST include + // as much of the IPv6 offending (invoking) packet (the + // packet that caused the error) as possible without making + // the error message packet exceed the minimum IPv6 MTU + // [IPv6]. + mtu := int(r.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize + available := int(mtu) - headerLen + payloadLen := len(netHeader) + vv.Size() + if payloadLen > available { + payloadLen = available + } + payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) + payload.Append(vv) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) + pkt.SetType(header.ICMPv6DstUnreachable) + pkt.SetCode(header.ICMPv6PortUnreachable) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) + r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + } return true } @@ -83,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{} - }) +// NewProtocol returns a UDP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 9da6edce2..b724d788c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -17,7 +17,6 @@ package udp_test import ( "bytes" "fmt" - "math" "math/rand" "testing" "time" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -274,13 +274,17 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + ep := channel.New(256, mtu, "") + wep := stack.LinkEndpoint(ep) - id, linkEP := channel.New(256, mtu, "") if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -306,7 +310,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } @@ -380,15 +384,17 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h) + c.injectV4Packet(payload, &h, true /* valid */) } else { - c.injectV6Packet(payload, &h) + c.injectV6Packet(payload, &h, true /* valid */) } } // injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -405,10 +411,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) + l := uint16(header.UDPMinimumSize + len(payload)) + if !valid { + // Change the UDP payload length to corrupt the header + // as requested by the caller. + l++ + } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), + Length: l, }) // Calculate the UDP pseudo-header checksum. @@ -422,9 +434,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -// injectV6Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { +// injectV4Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -461,101 +475,78 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { } func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) + return newMinPayload(30) +} + +func newMinPayload(minSize int) []byte { + b := make([]byte, minSize+rand.Intn(100)) for i := range b { b[i] = byte(rand.Intn(256)) } return b } -func TestBindPortReuse(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - c.createEndpoint(ipv6.ProtocolNumber) - - var eps [5]tcpip.Endpoint - reusePortOpt := tcpip.ReusePortOption(1) - - pollChannel := make(chan tcpip.Endpoint) - for i := 0; i < len(eps); i++ { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(eps[i]) - - defer eps[i].Close() - if err := eps[i].SetSockOpt(reusePortOpt); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) } + defer ep.Close() - npackets := 100000 - nports := 10000 - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.injectV6Packet(payload, &header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - }) + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled - // by the same socket. - if ports[port] != ep { - t.Fatalf("Port mismatch") - } - } + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) } - if len(stats) != len(eps) { - t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s } - // Check that a packet distribution is fair between sockets. - for _, c := range stats { - n := float64(npackets) / float64(len(eps)) - // The deviation is less than 10%. - if math.Abs(float64(c)-n) > n/10 { - t.Fatal(c, n) - } + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) } } -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// testReadInternal sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then attempts to read it from the +// UDP endpoint and depending on if this was expected to succeed verifies its +// correctness. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { c.t.Helper() payload := newPayload() @@ -566,6 +557,9 @@ func testRead(c *testContext, flow testFlow) { c.wq.EventRegister(&we, waiter.EventIn) defer c.wq.EventUnregister(&we) + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + var addr tcpip.FullAddress v, _, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { @@ -573,25 +567,55 @@ func testRead(c *testContext, flow testFlow) { select { case <-ch: v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") + case <-time.After(300 * time.Millisecond): + if packetShouldBeDropped { + return // expected to time out + } + c.t.Fatal("timed out waiting for data") } } + if expectReadError && err != nil { + c.checkEndpointReadStats(1, epstats, err) + return + } + + if err != nil { + c.t.Fatal("Read failed:", err) + } + + if packetShouldBeDropped { + c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) + } + // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) } // Check the payload. if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + c.checkEndpointReadStats(1, epstats, err) +} + +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) +} + +// testFailingRead sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then tries to read it from the UDP +// endpoint and expects this to fail. +func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { + c.t.Helper() + testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) } func TestBindEphemeralPort(t *testing.T) { @@ -763,13 +787,17 @@ func TestReadOnBoundToMulticast(t *testing.T) { c.t.Fatal("SetSockOpt failed:", err) } + // Check that we receive multicast packets but not unicast or broadcast + // ones. testRead(c, flow) + testFailingRead(c, broadcast, false /* expectReadError */) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and receive broadcast data on it. +// address and can receive only broadcast data. func TestV4ReadOnBoundToBroadcast(t *testing.T) { for _, flow := range []testFlow{broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -784,8 +812,31 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - // Test acceptance. + // Check that we receive broadcast packets but not unicast ones. testRead(c, flow) + testFailingRead(c, unicastV4, false /* expectReadError */) + }) + } +} + +// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY +// and receive broadcast and unicast data. +func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s (", err) + } + + // Check that we receive both broadcast and unicast packets. + testRead(c, flow) + testRead(c, unicastV4) }) } } @@ -794,7 +845,8 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { // and verifies it fails with the provided error code. func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { c.t.Helper() - + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) @@ -802,6 +854,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) + c.checkEndpointWriteStats(1, epstats, gotErr) if gotErr != wantErr { c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } @@ -827,6 +880,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() writeOpts := tcpip.WriteOptions{} if setDest { @@ -844,7 +899,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - + c.checkEndpointWriteStats(1, epstats, err) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -913,6 +968,10 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Write to V4 mapped address. testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + const want = 1 + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { + c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) + } } func TestDualWriteConnectedToV4Mapped(t *testing.T) { @@ -1175,6 +1234,109 @@ func TestTTL(t *testing.T) { } } +func TestSetTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() + } + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + ep.Close() + + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } + }) + } +} + +func TestTOSV4(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + const tos = 0xC0 + var v tcpip.IPv4TOSOption + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + // Test for expected default value. + if v != 0 { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + } + + if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { + c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + } + + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv4TOSOption(tos); v != want { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testWrite(c, flow, checker.TOS(tos, 0)) + }) + } +} + +func TestTOSV6(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + const tos = 0xC0 + var v tcpip.IPv6TrafficClassOption + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + // Test for expected default value. + if v != 0 { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + } + + if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { + c.t.Errorf("SetSockOpt failed: %s", err) + } + + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv6TrafficClassOption(tos); v != want { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testWrite(c, flow, checker.TOS(tos, 0)) + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1238,3 +1400,267 @@ func TestMulticastInterfaceOption(t *testing.T) { }) } } + +// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination +// Unreachable message when a udp datagram is received on ports for which there +// is no bound udp socket. +func TestV4UnknownDestination(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + testCases := []struct { + flow testFlow + icmpRequired bool + // largePayload if true, will result in a payload large enough + // so that the final generated IPv4 packet is larger than + // header.IPv4MinimumProcessableDatagramSize. + largePayload bool + }{ + {unicastV4, true, false}, + {unicastV4, true, true}, + {multicastV4, false, false}, + {multicastV4, false, true}, + {broadcast, false, false}, + {broadcast, false, true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + payload := newPayload() + if tc.largePayload { + payload = newMinPayload(576) + } + c.injectPacket(tc.flow, payload) + if !tc.icmpRequired { + select { + case p := <-c.linkEP.C: + t.Fatalf("unexpected packet received: %+v", p) + case <-time.After(1 * time.Second): + return + } + } + + select { + case p := <-c.linkEP.C: + var pkt []byte + pkt = append(pkt, p.Header...) + pkt = append(pkt, p.Payload...) + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } + + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } + + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) + } + case <-time.After(1 * time.Second): + t.Fatalf("packet wasn't written out") + } + }) + } +} + +// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination +// Unreachable message when a udp datagram is received on ports for which there +// is no bound udp socket. +func TestV6UnknownDestination(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + testCases := []struct { + flow testFlow + icmpRequired bool + // largePayload if true will result in a payload large enough to + // create an IPv6 packet > header.IPv6MinimumMTU bytes. + largePayload bool + }{ + {unicastV6, true, false}, + {unicastV6, true, true}, + {multicastV6, false, false}, + {multicastV6, false, true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + payload := newPayload() + if tc.largePayload { + payload = newMinPayload(1280) + } + c.injectPacket(tc.flow, payload) + if !tc.icmpRequired { + select { + case p := <-c.linkEP.C: + t.Fatalf("unexpected packet received: %+v", p) + case <-time.After(1 * time.Second): + return + } + } + + select { + case p := <-c.linkEP.C: + var pkt []byte + pkt = append(pkt, p.Header...) + pkt = append(pkt, p.Payload...) + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } + + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) + } + case <-time.After(1 * time.Second): + t.Fatalf("packet wasn't written out") + } + }) + } +} + +// TestIncrementMalformedPacketsReceived verifies if the malformed received +// global and endpoint stats get incremented. +func TestIncrementMalformedPacketsReceived(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + payload := newPayload() + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + c.injectV6Packet(payload, &h, false /* !valid */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } +} + +// TestShutdownRead verifies endpoint read shutdown and error +// stats increment on packet receive. +func TestShutdownRead(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingRead(c, unicastV6, true /* expectReadError */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { + t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) + } +} + +// TestShutdownWrite verifies endpoint write shutdown and error +// stats increment on packet write. +func TestShutdownWrite(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) +} + +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil: + want.PacketsSent.IncrementBy(incr) + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + want.WriteErrors.InvalidArgs.IncrementBy(incr) + case tcpip.ErrClosedForSend: + want.WriteErrors.WriteClosed.IncrementBy(incr) + case tcpip.ErrInvalidEndpointState: + want.WriteErrors.InvalidEndpointState.IncrementBy(incr) + case tcpip.ErrNoLinkAddress: + want.SendErrors.NoLinkAddr.IncrementBy(incr) + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + want.SendErrors.NoRoute.IncrementBy(incr) + default: + want.SendErrors.SendToNetworkFailed.IncrementBy(incr) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} + +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil, tcpip.ErrWouldBlock: + case tcpip.ErrClosedForReceive: + want.ReadErrors.ReadClosed.IncrementBy(incr) + default: + c.t.Errorf("Endpoint error missing stats update err %v", err) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 98d51cc69..6afdb29b7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index cbd92fc05..8f6f180e5 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b7f505a84..b6bbb0ea2 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 9173dfd0f..1f7efb064 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "waiter_list", |