Diffstat (limited to 'pkg'): 466 files changed, 9184 insertions, 5035 deletions
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index 0d921ed6f..cad24fcc7 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -19,8 +19,10 @@ package linux // See linux/magic.h. const ( ANON_INODE_FS_MAGIC = 0x09041934 + CGROUP_SUPER_MAGIC = 0x27e0eb DEVPTS_SUPER_MAGIC = 0x00001cd1 EXT_SUPER_MAGIC = 0xef53 + FUSE_SUPER_MAGIC = 0x65735546 OVERLAYFS_SUPER_MAGIC = 0x794c7630 PIPEFS_MAGIC = 0x50495045 PROC_SUPER_MAGIC = 0x9fa0 @@ -29,7 +31,6 @@ const ( SYSFS_MAGIC = 0x62656572 TMPFS_MAGIC = 0x01021994 V9FS_MAGIC = 0x01021997 - FUSE_SUPER_MAGIC = 0x65735546 ) // Filesystem path limits, from uapi/linux/limits.h. diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 378f1baf3..35c632168 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -145,13 +145,13 @@ func (ke *KernelIPTEntry) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { - ke.Entry.MarshalBytes(dst) + ke.Entry.MarshalUnsafe(dst) ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { - ke.Entry.UnmarshalBytes(src) + ke.Entry.UnmarshalUnsafe(src) ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } @@ -375,6 +375,17 @@ type XTRedirectTarget struct { // SizeOfXTRedirectTarget is the size of an XTRedirectTarget. const SizeOfXTRedirectTarget = 56 +// XTSNATTarget triggers Source NAT when reached. +// Adding 4 bytes of padding to make the struct 8 byte aligned. +type XTSNATTarget struct { + Target XTEntryTarget + NfRange NfNATIPV4MultiRangeCompat + _ [4]byte +} + +// SizeOfXTSNATTarget is the size of an XTSNATTarget. +const SizeOfXTSNATTarget = 56 + // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. // @@ -429,7 +440,7 @@ func (ke *KernelIPTGetEntries) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { - ke.IPTGetEntries.MarshalBytes(dst) + ke.IPTGetEntries.MarshalUnsafe(dst) marshalledUntil := ke.IPTGetEntries.SizeBytes() for i := range ke.Entrytable { ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) @@ -439,7 +450,7 @@ func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { - ke.IPTGetEntries.UnmarshalBytes(src) + ke.IPTGetEntries.UnmarshalUnsafe(src) unmarshalledUntil := ke.IPTGetEntries.SizeBytes() for i := range ke.Entrytable { ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index b953e62dc..f7c70b430 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -86,7 +86,7 @@ func (ke *KernelIP6TGetEntries) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) { - ke.IPTGetEntries.MarshalBytes(dst) + ke.IPTGetEntries.MarshalUnsafe(dst) marshalledUntil := ke.IPTGetEntries.SizeBytes() for i := range ke.Entrytable { ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) @@ -96,7 +96,7 @@ func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) { // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. 
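The netfilter hunks above all follow the same pattern: the fixed-size inner struct (IPTEntry, IPTGetEntries, IP6TEntry) is now serialized with MarshalUnsafe, a direct memory copy, while only the variable-length Elems/Entrytable portion still goes through MarshalBytes. The sketch below illustrates that split with self-contained, hypothetical types (Header and Packet are not gVisor's); the real code relies on the go-marshal generated methods rather than hand-written unsafe code.

package main

import (
	"encoding/binary"
	"fmt"
	"unsafe"
)

// Header is a fixed-size, packed struct, playing the role of IPTEntry.
type Header struct {
	Kind uint32
	Len  uint32
}

// SizeBytes returns the serialized size of the header.
func (h *Header) SizeBytes() int { return int(unsafe.Sizeof(*h)) }

// MarshalUnsafe copies the header's memory directly into dst, analogous to
// the go-marshal generated unsafe path.
func (h *Header) MarshalUnsafe(dst []byte) {
	copy(dst, unsafe.Slice((*byte)(unsafe.Pointer(h)), h.SizeBytes()))
}

// Packet pairs the fixed header with a variable-length payload, playing the
// role of KernelIPTEntry pairing Entry with Elems.
type Packet struct {
	Header  Header
	Payload []byte
}

// MarshalBytes serializes the fixed part via the unsafe path and the variable
// part via an explicit copy, mirroring KernelIPTEntry.MarshalBytes after this
// change.
func (p *Packet) MarshalBytes(dst []byte) {
	p.Header.MarshalUnsafe(dst)
	copy(dst[p.Header.SizeBytes():], p.Payload)
}

func main() {
	p := Packet{Header: Header{Kind: 1, Len: 4}, Payload: []byte{0xde, 0xad, 0xbe, 0xef}}
	buf := make([]byte, p.Header.SizeBytes()+len(p.Payload))
	p.MarshalBytes(buf)
	fmt.Println(binary.LittleEndian.Uint32(buf[0:4]), buf[p.Header.SizeBytes():])
}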
func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) { - ke.IPTGetEntries.UnmarshalBytes(src) + ke.IPTGetEntries.UnmarshalUnsafe(src) unmarshalledUntil := ke.IPTGetEntries.SizeBytes() for i := range ke.Entrytable { ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) @@ -149,8 +149,8 @@ type IP6TEntry struct { const SizeOfIP6TEntry = 168 // KernelIP6TEntry is identical to IP6TEntry, but includes the Elems field. -// KernelIP6TEntry itself is not Marshallable but it implements some methods of -// marshal.Marshallable that help in other implementations of Marshallable. +// +// +marshal dynamic type KernelIP6TEntry struct { Entry IP6TEntry @@ -168,13 +168,13 @@ func (ke *KernelIP6TEntry) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) { - ke.Entry.MarshalBytes(dst) + ke.Entry.MarshalUnsafe(dst) ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) { - ke.Entry.UnmarshalBytes(src) + ke.Entry.UnmarshalUnsafe(src) ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go index 50e22fe7e..e722971f1 100644 --- a/pkg/abi/linux/ptrace_amd64.go +++ b/pkg/abi/linux/ptrace_amd64.go @@ -61,3 +61,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 { func (p *PtraceRegs) StackPointer() uint64 { return p.Rsp } + +// SetStackPointer sets the stack pointer to the specified value. +func (p *PtraceRegs) SetStackPointer(sp uint64) { + p.Rsp = sp +} diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go index da36811d2..3d0906565 100644 --- a/pkg/abi/linux/ptrace_arm64.go +++ b/pkg/abi/linux/ptrace_arm64.go @@ -38,3 +38,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 { func (p *PtraceRegs) StackPointer() uint64 { return p.Sp } + +// SetStackPointer sets the stack pointer to the specified value. +func (p *PtraceRegs) SetStackPointer(sp uint64) { + p.Sp = sp +} diff --git a/pkg/coverage/BUILD b/pkg/coverage/BUILD index a198e8028..ace5895f8 100644 --- a/pkg/coverage/BUILD +++ b/pkg/coverage/BUILD @@ -7,8 +7,8 @@ go_library( srcs = ["coverage.go"], visibility = ["//:sandbox"], deps = [ + "//pkg/hostarch", "//pkg/sync", - "//pkg/usermem", "@io_bazel_rules_go//go/tools/coverdata", ], ) diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go index 6f3d72e83..b33a20802 100644 --- a/pkg/coverage/coverage.go +++ b/pkg/coverage/coverage.go @@ -26,20 +26,25 @@ import ( "fmt" "io" "sort" + "sync/atomic" "testing" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" "github.com/bazelbuild/rules_go/go/tools/coverdata" ) -// coverageMu must be held while accessing coverdata.Cover. This prevents -// concurrent reads/writes from multiple threads collecting coverage data. -var coverageMu sync.RWMutex +var ( + // coverageMu must be held while accessing coverdata.Cover. This prevents + // concurrent reads/writes from multiple threads collecting coverage data. + coverageMu sync.RWMutex -// once ensures that globalData is only initialized once. -var once sync.Once + // reportOutput is the place to write out a coverage report. It should be + // closed after the report is written. It is protected by reportOutputMu. 
+ reportOutput io.WriteCloser + reportOutputMu sync.Mutex +) // blockBitLength is the number of bits used to represent coverage block index // in a synthetic PC (the rest are used to represent the file index). Even @@ -51,12 +56,26 @@ var once sync.Once // file and every block. const blockBitLength = 16 -// KcovAvailable returns whether the kcov coverage interface is available. It is -// available as long as coverage is enabled for some files. -func KcovAvailable() bool { +// Available returns whether any coverage data is available. +func Available() bool { return len(coverdata.Cover.Blocks) > 0 } +// EnableReport sets up coverage reporting. +func EnableReport(w io.WriteCloser) { + reportOutputMu.Lock() + defer reportOutputMu.Unlock() + reportOutput = w +} + +// KcovSupported returns whether the kcov interface should be made available. +// +// If coverage reporting is on, do not turn on kcov, which will consume +// coverage data. +func KcovSupported() bool { + return (reportOutput == nil) && Available() +} + var globalData struct { // files is the set of covered files sorted by filename. It is calculated at // startup. @@ -65,6 +84,9 @@ var globalData struct { // syntheticPCs are a set of PCs calculated at startup, where the PC // at syntheticPCs[i][j] corresponds to file i, block j. syntheticPCs [][]uint64 + + // once ensures that globalData is only initialized once. + once sync.Once } // ClearCoverageData clears existing coverage data. @@ -141,7 +163,7 @@ func ConsumeCoverageData(w io.Writer) int { // Non-zero coverage data found; consume it and report as a PC. counters[index] = 0 pc := globalData.syntheticPCs[fileNum][index] - usermem.ByteOrder.PutUint64(pcBuffer[:], pc) + hostarch.ByteOrder.PutUint64(pcBuffer[:], pc) n, err := w.Write(pcBuffer[:]) if err != nil { if err == io.EOF { @@ -166,7 +188,7 @@ func ConsumeCoverageData(w io.Writer) int { // InitCoverageData initializes globalData. It should be called before any kcov // data is written. func InitCoverageData() { - once.Do(func() { + globalData.once.Do(func() { // First, order all files. Then calculate synthetic PCs for every block // (using the well-defined ordering for files as well). for file := range coverdata.Cover.Blocks { @@ -185,6 +207,38 @@ func InitCoverageData() { }) } +// reportOnce ensures that a coverage report is written at most once. For a +// complete coverage report, Report should be called during the sandbox teardown +// process. Report is called from multiple places (which may overlap) so that a +// coverage report is written in different sandbox exit scenarios. +var reportOnce sync.Once + +// Report writes out a coverage report with all blocks that have been covered. +// +// TODO(b/144576401): Decide whether this should actually be in LCOV format +func Report() error { + if reportOutput == nil { + return nil + } + + var err error + reportOnce.Do(func() { + for file, counters := range coverdata.Cover.Counters { + blocks := coverdata.Cover.Blocks[file] + for i := 0; i < len(counters); i++ { + if atomic.LoadUint32(&counters[i]) > 0 { + err = writeBlock(reportOutput, file, blocks[i]) + if err != nil { + return + } + } + } + } + reportOutput.Close() + }) + return err +} + // Symbolize prints information about the block corresponding to pc. 
func Symbolize(out io.Writer, pc uint64) error { fileNum, blockNum := syntheticPCToIndexes(pc) @@ -196,18 +250,32 @@ func Symbolize(out io.Writer, pc uint64) error { if err != nil { return err } - writeBlock(out, pc, file, block) - return nil + return writeBlockWithPC(out, pc, file, block) } // WriteAllBlocks prints all information about all blocks along with their // corresponding synthetic PCs. -func WriteAllBlocks(out io.Writer) { +func WriteAllBlocks(out io.Writer) error { for fileNum, file := range globalData.files { for blockNum, block := range coverdata.Cover.Blocks[file] { - writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block) + if err := writeBlockWithPC(out, calculateSyntheticPC(fileNum, blockNum), file, block); err != nil { + return err + } } } + return nil +} + +func writeBlockWithPC(out io.Writer, pc uint64, file string, block testing.CoverBlock) error { + if _, err := io.WriteString(out, fmt.Sprintf("%#x\n", pc)); err != nil { + return err + } + return writeBlock(out, file, block) +} + +func writeBlock(out io.Writer, file string, block testing.CoverBlock) error { + _, err := io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1)) + return err } func calculateSyntheticPC(fileNum int, blockNum int) uint64 { @@ -239,8 +307,3 @@ func blockFromIndex(file string, i int) (testing.CoverBlock, error) { } return blocks[i], nil } - -func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) { - io.WriteString(out, fmt.Sprintf("%#x\n", pc)) - io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1)) -} diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD index 35683fe98..b4e05f922 100644 --- a/pkg/gohacks/BUILD +++ b/pkg/gohacks/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,3 +10,11 @@ go_library( stateify = False, visibility = ["//:sandbox"], ) + +go_test( + name = "gohacks_test", + size = "small", + srcs = ["gohacks_test.go"], + library = ":gohacks", + deps = ["@org_golang_x_sys//unix:go_default_library"], +) diff --git a/pkg/gohacks/gohacks_test.go b/pkg/gohacks/gohacks_test.go new file mode 100644 index 000000000..e18c8abc7 --- /dev/null +++ b/pkg/gohacks/gohacks_test.go @@ -0,0 +1,97 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gohacks + +import ( + "io/ioutil" + "math/rand" + "os" + "runtime/debug" + "testing" + + "golang.org/x/sys/unix" +) + +func randBuf(size int) []byte { + b := make([]byte, size) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +// Size of a page in bytes. Cloned from hostarch.PageSize to avoid a circular +// dependency. 
+const pageSize = 4096 + +func testCopy(dst, src []byte) (panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + debug.SetPanicOnFault(true) + copy(dst, src) + return panicked +} + +func TestSegVOnMemmove(t *testing.T) { + // Test that SIGSEGVs received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer unix.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} + +func TestSigbusOnMemmove(t *testing.T) { + // Test that SIGBUS received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + f, err := ioutil.TempFile("", "sigbus_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + os.Remove(f.Name()) + defer f.Close() + + a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer unix.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go index 10bbb1f58..374aac2b4 100644 --- a/pkg/gohacks/gohacks_unsafe.go +++ b/pkg/gohacks/gohacks_unsafe.go @@ -75,3 +75,17 @@ func StringFromImmutableBytes(bs []byte) string { // strings.Builder.String(). return *(*string)(unsafe.Pointer(&bs)) } + +// Note that go:linkname silently doesn't work if the local name is exported, +// necessitating an indirection for exported functions. + +// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>. +// +//go:nosplit +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) diff --git a/pkg/hostarch/BUILD b/pkg/hostarch/BUILD new file mode 100644 index 000000000..1e8def4d9 --- /dev/null +++ b/pkg/hostarch/BUILD @@ -0,0 +1,42 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "addr_range", + out = "addr_range.go", + package = "hostarch", + prefix = "Addr", + template = "//pkg/segment:generic_range", + types = { + "T": "Addr", + }, +) + +go_test( + name = "hostarch_test", + size = "small", + srcs = [ + "addr_range_seq_test.go", + ], + library = ":hostarch", +) + +go_library( + name = "hostarch", + srcs = [ + "access_type.go", + "addr.go", + "addr_range.go", + "addr_range_seq_unsafe.go", + "hostarch.go", + "hostarch_arm64.go", + "hostarch_x86.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/gohacks", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/usermem/access_type.go b/pkg/hostarch/access_type.go index 2cfca29af..e30476840 100644 --- a/pkg/usermem/access_type.go +++ b/pkg/hostarch/access_type.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package usermem +package hostarch import "golang.org/x/sys/unix" diff --git a/pkg/usermem/addr.go b/pkg/hostarch/addr.go index c4100481e..0cf0f3c81 100644 --- a/pkg/usermem/addr.go +++ b/pkg/hostarch/addr.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package usermem +package hostarch import ( "fmt" @@ -57,7 +57,7 @@ func (v Addr) RoundUp() (addr Addr, ok bool) { func (v Addr) MustRoundUp() Addr { addr, ok := v.RoundUp() if !ok { - panic(fmt.Sprintf("usermem.Addr(%d).RoundUp() wraps", v)) + panic(fmt.Sprintf("hostarch.Addr(%d).RoundUp() wraps", v)) } return addr } diff --git a/pkg/usermem/addr_range_seq_test.go b/pkg/hostarch/addr_range_seq_test.go index 82f735026..5726dfd19 100644 --- a/pkg/usermem/addr_range_seq_test.go +++ b/pkg/hostarch/addr_range_seq_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package usermem +package hostarch import ( "testing" diff --git a/pkg/usermem/addr_range_seq_unsafe.go b/pkg/hostarch/addr_range_seq_unsafe.go index c9a1415a0..ecc17d595 100644 --- a/pkg/usermem/addr_range_seq_unsafe.go +++ b/pkg/hostarch/addr_range_seq_unsafe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package usermem +package hostarch import ( "bytes" diff --git a/pkg/hostarch/hostarch.go b/pkg/hostarch/hostarch.go new file mode 100644 index 000000000..fdd29c567 --- /dev/null +++ b/pkg/hostarch/hostarch.go @@ -0,0 +1,7 @@ +// Copyright 2021 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package hostarch contains host arch address operations for user memory. +package hostarch diff --git a/pkg/usermem/usermem_arm64.go b/pkg/hostarch/hostarch_arm64.go index 7e7529585..a31a8aeeb 100644 --- a/pkg/usermem/usermem_arm64.go +++ b/pkg/hostarch/hostarch_arm64.go @@ -14,7 +14,7 @@ // +build arm64 -package usermem +package hostarch import ( "encoding/binary" diff --git a/pkg/usermem/usermem_x86.go b/pkg/hostarch/hostarch_x86.go index d96f829fb..af6ef2b7f 100644 --- a/pkg/usermem/usermem_x86.go +++ b/pkg/hostarch/hostarch_x86.go @@ -14,7 +14,7 @@ // +build amd64 386 -package usermem +package hostarch import "encoding/binary" diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD index aac0161fa..7cd89e639 100644 --- a/pkg/marshal/BUILD +++ b/pkg/marshal/BUILD @@ -11,5 +11,5 @@ go_library( visibility = [ "//:sandbox", ], - deps = ["//pkg/usermem"], + deps = ["//pkg/hostarch"], ) diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go index d8cb44b40..7da450ce8 100644 --- a/pkg/marshal/marshal.go +++ b/pkg/marshal/marshal.go @@ -23,7 +23,7 @@ package marshal import ( "io" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // CopyContext defines the memory operations required to marshal to and from @@ -36,11 +36,11 @@ type CopyContext interface { // CopyOutBytes writes the contents of b to the task's memory. See // kernel.CopyOutBytes. - CopyOutBytes(addr usermem.Addr, b []byte) (int, error) + CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) // CopyInBytes reads the contents of the task's memory to b. See // kernel.CopyInBytes. 
- CopyInBytes(addr usermem.Addr, b []byte) (int, error) + CopyInBytes(addr hostarch.Addr, b []byte) (int, error) } // Marshallable represents operations on a type that can be marshalled to and @@ -108,7 +108,7 @@ type Marshallable interface { // If the copy-in from the task memory is only partially successful, CopyIn // should still attempt to deserialize as much data as possible. See comment // for UnmarshalBytes. - CopyIn(cc CopyContext, addr usermem.Addr) (int, error) + CopyIn(cc CopyContext, addr hostarch.Addr) (int, error) // CopyOut serializes a Marshallable type to a task's memory. This may only // be called from a task goroutine. This is more efficient than calling @@ -119,7 +119,7 @@ type Marshallable interface { // The copy-out to the task memory may be partially successful, in which // case CopyOut returns how much data was serialized. See comment for // MarshalBytes for implications. - CopyOut(cc CopyContext, addr usermem.Addr) (int, error) + CopyOut(cc CopyContext, addr hostarch.Addr) (int, error) // CopyOutN is like CopyOut, but explicitly requests a partial // copy-out. Note that this may yield unexpected results for non-packed @@ -127,7 +127,7 @@ type Marshallable interface { // comment on MarshalBytes. // // The limit must be less than or equal to SizeBytes(). - CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) + CopyOutN(cc CopyContext, addr hostarch.Addr, limit int) (int, error) } // go-marshal generates additional functions for a type based on additional @@ -157,15 +157,18 @@ type Marshallable interface { // func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... } // // // CopyFooSliceIn copies in a slice of Foo objects from the task's memory. -// func CopyFooSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Foo) (int, error) { ... } +// func CopyFooSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []Foo) (int, error) { ... } // // // CopyFooSliceIn copies out a slice of Foo objects to the task's memory. -// func CopyFooSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []Foo) (int, error) { ... } +// func CopyFooSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []Foo) (int, error) { ... } // // The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where // %s is the first argument to the slice clause. This directive is not supported // for newtypes on arrays. // +// Note: Partial copies are not supported for Slice API UnmarshalUnsafe and +// MarshalUnsafe. +// // The slice clause also takes an optional second argument, which must be the // value "inner": // @@ -175,10 +178,10 @@ type Marshallable interface { // This is only valid on newtypes on primitives, and causes the generated // functions to accept slices of the inner type instead: // -// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []int32) (int, error) { ... } +// func CopyInt32SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []int32) (int, error) { ... } // // Without "inner", they would instead be: // -// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Int32) (int, error) { ... } +// func CopyInt32SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []Int32) (int, error) { ... } // // This may help avoid a cast depending on how the generated functions are used. 
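As a concrete illustration of the slice clause documented above: the directive is attached to a newtype on a primitive, and with the optional "inner" argument the generated helpers accept slices of the underlying type. The directive spelling below is an assumption inferred from this documentation, not something shown in the diff; the generated signatures are quoted from the comments above.

// Hypothetical newtype with the slice clause and "inner".
//
// +marshal slice:Int32Slice:inner
type Int32 int32

// Per the documentation above, go-marshal would then emit (among others):
//
//   func CopyInt32SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []int32) (int, error)
//   func CopyInt32SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []int32) (int, error)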
diff --git a/pkg/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go index ea75e09f2..9e6a6fa29 100644 --- a/pkg/marshal/marshal_impl_util.go +++ b/pkg/marshal/marshal_impl_util.go @@ -17,7 +17,7 @@ package marshal import ( "io" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // StubMarshallable implements the Marshallable interface. @@ -63,16 +63,16 @@ func (StubMarshallable) UnmarshalUnsafe(src []byte) { } // CopyIn implements Marshallable.CopyIn. -func (StubMarshallable) CopyIn(cc CopyContext, addr usermem.Addr) (int, error) { +func (StubMarshallable) CopyIn(cc CopyContext, addr hostarch.Addr) (int, error) { panic("Please implement your own CopyIn function") } // CopyOut implements Marshallable.CopyOut. -func (StubMarshallable) CopyOut(cc CopyContext, addr usermem.Addr) (int, error) { +func (StubMarshallable) CopyOut(cc CopyContext, addr hostarch.Addr) (int, error) { panic("Please implement your own CopyOut function") } // CopyOutN implements Marshallable.CopyOutN. -func (StubMarshallable) CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) { +func (StubMarshallable) CopyOutN(cc CopyContext, addr hostarch.Addr, limit int) (int, error) { panic("Please implement your own CopyOutN function") } diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD index d77a11c79..190b57c29 100644 --- a/pkg/marshal/primitive/BUILD +++ b/pkg/marshal/primitive/BUILD @@ -13,6 +13,7 @@ go_library( ], deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/marshal", "//pkg/usermem", ], diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index 4b342de6b..32c8ed138 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -20,6 +20,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/usermem" ) @@ -102,17 +103,17 @@ func (b *ByteSlice) UnmarshalUnsafe(src []byte) { } // CopyIn implements marshal.Marshallable.CopyIn. -func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { +func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { return cc.CopyInBytes(addr, *b) } // CopyOut implements marshal.Marshallable.CopyOut. -func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { +func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { return cc.CopyOutBytes(addr, *b) } // CopyOutN implements marshal.Marshallable.CopyOutN. -func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { +func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { return cc.CopyOutBytes(addr, (*b)[:limit]) } @@ -131,7 +132,7 @@ var _ marshal.Marshallable = (*ByteSlice)(nil) // CopyInt8In is a convenient wrapper for copying in an int8 from the task's // memory. -func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, error) { +func CopyInt8In(cc marshal.CopyContext, addr hostarch.Addr, dst *int8) (int, error) { var buf Int8 n, err := buf.CopyIn(cc, addr) if err != nil { @@ -143,14 +144,14 @@ func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, erro // CopyInt8Out is a convenient wrapper for copying out an int8 to the task's // memory. 
-func CopyInt8Out(cc marshal.CopyContext, addr usermem.Addr, src int8) (int, error) { +func CopyInt8Out(cc marshal.CopyContext, addr hostarch.Addr, src int8) (int, error) { srcP := Int8(src) return srcP.CopyOut(cc, addr) } // CopyUint8In is a convenient wrapper for copying in a uint8 from the task's // memory. -func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, error) { +func CopyUint8In(cc marshal.CopyContext, addr hostarch.Addr, dst *uint8) (int, error) { var buf Uint8 n, err := buf.CopyIn(cc, addr) if err != nil { @@ -162,7 +163,7 @@ func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, er // CopyUint8Out is a convenient wrapper for copying out a uint8 to the task's // memory. -func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, error) { +func CopyUint8Out(cc marshal.CopyContext, addr hostarch.Addr, src uint8) (int, error) { srcP := Uint8(src) return srcP.CopyOut(cc, addr) } @@ -171,7 +172,7 @@ func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, er // CopyInt16In is a convenient wrapper for copying in an int16 from the task's // memory. -func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, error) { +func CopyInt16In(cc marshal.CopyContext, addr hostarch.Addr, dst *int16) (int, error) { var buf Int16 n, err := buf.CopyIn(cc, addr) if err != nil { @@ -183,14 +184,14 @@ func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, er // CopyInt16Out is a convenient wrapper for copying out an int16 to the task's // memory. -func CopyInt16Out(cc marshal.CopyContext, addr usermem.Addr, src int16) (int, error) { +func CopyInt16Out(cc marshal.CopyContext, addr hostarch.Addr, src int16) (int, error) { srcP := Int16(src) return srcP.CopyOut(cc, addr) } // CopyUint16In is a convenient wrapper for copying in a uint16 from the task's // memory. -func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int, error) { +func CopyUint16In(cc marshal.CopyContext, addr hostarch.Addr, dst *uint16) (int, error) { var buf Uint16 n, err := buf.CopyIn(cc, addr) if err != nil { @@ -202,7 +203,7 @@ func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int, // CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's // memory. -func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int, error) { +func CopyUint16Out(cc marshal.CopyContext, addr hostarch.Addr, src uint16) (int, error) { srcP := Uint16(src) return srcP.CopyOut(cc, addr) } @@ -211,7 +212,7 @@ func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int, // CopyInt32In is a convenient wrapper for copying in an int32 from the task's // memory. -func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, error) { +func CopyInt32In(cc marshal.CopyContext, addr hostarch.Addr, dst *int32) (int, error) { var buf Int32 n, err := buf.CopyIn(cc, addr) if err != nil { @@ -223,14 +224,14 @@ func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, er // CopyInt32Out is a convenient wrapper for copying out an int32 to the task's // memory. -func CopyInt32Out(cc marshal.CopyContext, addr usermem.Addr, src int32) (int, error) { +func CopyInt32Out(cc marshal.CopyContext, addr hostarch.Addr, src int32) (int, error) { srcP := Int32(src) return srcP.CopyOut(cc, addr) } // CopyUint32In is a convenient wrapper for copying in a uint32 from the task's // memory. 
-func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int, error) { +func CopyUint32In(cc marshal.CopyContext, addr hostarch.Addr, dst *uint32) (int, error) { var buf Uint32 n, err := buf.CopyIn(cc, addr) if err != nil { @@ -242,7 +243,7 @@ func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int, // CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's // memory. -func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int, error) { +func CopyUint32Out(cc marshal.CopyContext, addr hostarch.Addr, src uint32) (int, error) { srcP := Uint32(src) return srcP.CopyOut(cc, addr) } @@ -251,7 +252,7 @@ func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int, // CopyInt64In is a convenient wrapper for copying in an int64 from the task's // memory. -func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, error) { +func CopyInt64In(cc marshal.CopyContext, addr hostarch.Addr, dst *int64) (int, error) { var buf Int64 n, err := buf.CopyIn(cc, addr) if err != nil { @@ -263,14 +264,14 @@ func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, er // CopyInt64Out is a convenient wrapper for copying out an int64 to the task's // memory. -func CopyInt64Out(cc marshal.CopyContext, addr usermem.Addr, src int64) (int, error) { +func CopyInt64Out(cc marshal.CopyContext, addr hostarch.Addr, src int64) (int, error) { srcP := Int64(src) return srcP.CopyOut(cc, addr) } // CopyUint64In is a convenient wrapper for copying in a uint64 from the task's // memory. -func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int, error) { +func CopyUint64In(cc marshal.CopyContext, addr hostarch.Addr, dst *uint64) (int, error) { var buf Uint64 n, err := buf.CopyIn(cc, addr) if err != nil { @@ -282,14 +283,14 @@ func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int, // CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's // memory. -func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int, error) { +func CopyUint64Out(cc marshal.CopyContext, addr hostarch.Addr, src uint64) (int, error) { srcP := Uint64(src) return srcP.CopyOut(cc, addr) } // CopyByteSliceIn is a convenient wrapper for copying in a []byte from the // task's memory. -func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (int, error) { +func CopyByteSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst *[]byte) (int, error) { var buf ByteSlice n, err := buf.CopyIn(cc, addr) if err != nil { @@ -301,14 +302,14 @@ func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (in // CopyByteSliceOut is a convenient wrapper for copying out a []byte to the // task's memory. -func CopyByteSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []byte) (int, error) { +func CopyByteSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []byte) (int, error) { srcP := ByteSlice(src) return srcP.CopyOut(cc, addr) } // CopyStringIn is a convenient wrapper for copying in a string from the // task's memory. 
-func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int, error) { +func CopyStringIn(cc marshal.CopyContext, addr hostarch.Addr, dst *string) (int, error) { var buf ByteSlice n, err := buf.CopyIn(cc, addr) if err != nil { @@ -320,12 +321,12 @@ func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int, // CopyStringOut is a convenient wrapper for copying out a string to the task's // memory. -func CopyStringOut(cc marshal.CopyContext, addr usermem.Addr, src string) (int, error) { +func CopyStringOut(cc marshal.CopyContext, addr hostarch.Addr, src string) (int, error) { srcP := ByteSlice(src) return srcP.CopyOut(cc, addr) } -// IOCopyContext wraps an object implementing usermem.IO to implement +// IOCopyContext wraps an object implementing hostarch.IO to implement // marshal.CopyContext. type IOCopyContext struct { Ctx context.Context @@ -339,11 +340,11 @@ func (i *IOCopyContext) CopyScratchBuffer(size int) []byte { } // CopyOutBytes implements marshal.CopyContext.CopyOutBytes. -func (i *IOCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { +func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) { return i.IO.CopyOut(i.Ctx, addr, b, i.Opts) } // CopyInBytes implements marshal.CopyContext.CopyInBytes. -func (i *IOCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { +func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) { return i.IO.CopyIn(i.Ctx, addr, b, i.Opts) } diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD index 501a9ef21..dcd6c3bf5 100644 --- a/pkg/merkletree/BUILD +++ b/pkg/merkletree/BUILD @@ -8,7 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/usermem", + "//pkg/hostarch", ], ) @@ -18,6 +18,6 @@ go_test( library = ":merkletree", deps = [ "//pkg/abi/linux", - "//pkg/usermem", + "//pkg/hostarch", ], ) diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index d7209ace3..6450f664c 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -24,7 +24,8 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -65,7 +66,7 @@ type Layout struct { // of a tree. dataSize specifies the size of input data in bytes. func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) { layout := Layout{ - blockSize: usermem.PageSize, + blockSize: hostarch.PageSize, } // TODO(b/156980949): Allow config SHA384. 
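Below is a hedged usage sketch of InitLayout after the usermem-to-hostarch rename, using only the signature shown above; the data size and hash algorithm are arbitrary, and the merkletree package is sentry-internal, so this is illustrative rather than buildable outside the sandbox.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/merkletree"
)

func main() {
	// The tree's block size now derives from hostarch.PageSize rather than
	// usermem.PageSize; callers only see the change through the new import.
	dataSize := 8 * int64(hostarch.PageSize)
	layout, err := merkletree.InitLayout(dataSize, linux.FS_VERITY_HASH_ALG_SHA256, true /* dataAndTreeInSameFile */)
	if err != nil {
		panic(err)
	}
	fmt.Printf("layout: %+v\n", layout)
}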
@@ -237,6 +238,7 @@ func Generate(params *GenerateParams) ([]byte, error) { Mode: params.Mode, UID: params.UID, GID: params.GID, + Children: params.Children, SymlinkTarget: params.SymlinkTarget, } diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index ed332b3f1..5d6f8df1b 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -24,7 +24,8 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) func TestLayout(t *testing.T) { @@ -58,7 +59,7 @@ func TestLayout(t *testing.T) { hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, expectedDigestSize: 32, - expectedLevelOffset: []int64{usermem.PageSize}, + expectedLevelOffset: []int64{hostarch.PageSize}, }, { name: "SmallSizeSHA512SameFile", @@ -66,7 +67,7 @@ func TestLayout(t *testing.T) { hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, expectedDigestSize: 64, - expectedLevelOffset: []int64{usermem.PageSize}, + expectedLevelOffset: []int64{hostarch.PageSize}, }, { name: "MiddleSizeSHA256SeparateFile", @@ -74,7 +75,7 @@ func TestLayout(t *testing.T) { hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, expectedDigestSize: 32, - expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, + expectedLevelOffset: []int64{0, 2 * hostarch.PageSize, 3 * hostarch.PageSize}, }, { name: "MiddleSizeSHA512SeparateFile", @@ -82,7 +83,7 @@ func TestLayout(t *testing.T) { hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, expectedDigestSize: 64, - expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize}, + expectedLevelOffset: []int64{0, 4 * hostarch.PageSize, 5 * hostarch.PageSize}, }, { name: "MiddleSizeSHA256SameFile", @@ -90,7 +91,7 @@ func TestLayout(t *testing.T) { hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, expectedDigestSize: 32, - expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, + expectedLevelOffset: []int64{245 * hostarch.PageSize, 247 * hostarch.PageSize, 248 * hostarch.PageSize}, }, { name: "MiddleSizeSHA512SameFile", @@ -98,39 +99,39 @@ func TestLayout(t *testing.T) { hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, expectedDigestSize: 64, - expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize}, + expectedLevelOffset: []int64{245 * hostarch.PageSize, 249 * hostarch.PageSize, 250 * hostarch.PageSize}, }, { name: "LargeSizeSHA256SeparateFile", - dataSize: 4096 * int64(usermem.PageSize), + dataSize: 4096 * int64(hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, expectedDigestSize: 32, - expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, + expectedLevelOffset: []int64{0, 32 * hostarch.PageSize, 33 * hostarch.PageSize}, }, { name: "LargeSizeSHA512SeparateFile", - dataSize: 4096 * int64(usermem.PageSize), + dataSize: 4096 * int64(hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, expectedDigestSize: 64, - expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize}, + expectedLevelOffset: []int64{0, 64 * hostarch.PageSize, 65 * hostarch.PageSize}, }, { name: "LargeSizeSHA256SameFile", - dataSize: 4096 * int64(usermem.PageSize), + dataSize: 4096 * int64(hostarch.PageSize), 
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, expectedDigestSize: 32, - expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, + expectedLevelOffset: []int64{4096 * hostarch.PageSize, 4128 * hostarch.PageSize, 4129 * hostarch.PageSize}, }, { name: "LargeSizeSHA512SameFile", - dataSize: 4096 * int64(usermem.PageSize), + dataSize: 4096 * int64(hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, expectedDigestSize: 64, - expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize}, + expectedLevelOffset: []int64{4096 * hostarch.PageSize, 4160 * hostarch.PageSize, 4161 * hostarch.PageSize}, }, } @@ -140,8 +141,8 @@ func TestLayout(t *testing.T) { if err != nil { t.Fatalf("Failed to InitLayout: %v", err) } - if l.blockSize != int64(usermem.PageSize) { - t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) + if l.blockSize != int64(hostarch.PageSize) { + t.Errorf("Got blockSize %d, want %d", l.blockSize, hostarch.PageSize) } if l.digestSize != tc.expectedDigestSize { t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) @@ -202,56 +203,56 @@ func TestGenerate(t *testing.T) { }{ { name: "OnePageZeroesSHA256SeparateFile", - data: bytes.Repeat([]byte{0}, usermem.PageSize), + data: bytes.Repeat([]byte{0}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, expectedHash: []byte{9, 115, 238, 230, 38, 140, 195, 70, 207, 144, 202, 118, 23, 113, 32, 129, 226, 239, 177, 69, 161, 26, 14, 113, 16, 37, 30, 96, 19, 148, 132, 27}, }, { name: "OnePageZeroesSHA256SameFile", - data: bytes.Repeat([]byte{0}, usermem.PageSize), + data: bytes.Repeat([]byte{0}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, expectedHash: []byte{9, 115, 238, 230, 38, 140, 195, 70, 207, 144, 202, 118, 23, 113, 32, 129, 226, 239, 177, 69, 161, 26, 14, 113, 16, 37, 30, 96, 19, 148, 132, 27}, }, { name: "OnePageZeroesSHA512SeparateFile", - data: bytes.Repeat([]byte{0}, usermem.PageSize), + data: bytes.Repeat([]byte{0}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, expectedHash: []byte{127, 8, 95, 11, 83, 101, 51, 39, 170, 235, 39, 43, 135, 243, 145, 118, 148, 58, 27, 155, 182, 205, 44, 47, 5, 223, 215, 17, 35, 16, 43, 104, 43, 11, 8, 88, 171, 7, 249, 243, 14, 62, 126, 218, 23, 159, 237, 237, 42, 226, 39, 25, 87, 48, 253, 191, 116, 213, 37, 3, 187, 152, 154, 14}, }, { name: "OnePageZeroesSHA512SameFile", - data: bytes.Repeat([]byte{0}, usermem.PageSize), + data: bytes.Repeat([]byte{0}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, expectedHash: []byte{127, 8, 95, 11, 83, 101, 51, 39, 170, 235, 39, 43, 135, 243, 145, 118, 148, 58, 27, 155, 182, 205, 44, 47, 5, 223, 215, 17, 35, 16, 43, 104, 43, 11, 8, 88, 171, 7, 249, 243, 14, 62, 126, 218, 23, 159, 237, 237, 42, 226, 39, 25, 87, 48, 253, 191, 116, 213, 37, 3, 187, 152, 154, 14}, }, { name: "MultiplePageZeroesSHA256SeparateFile", - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, expectedHash: []byte{247, 158, 42, 215, 180, 106, 0, 28, 77, 64, 132, 162, 74, 65, 250, 161, 243, 66, 129, 44, 197, 8, 145, 14, 94, 206, 156, 184, 145, 145, 20, 185}, }, { 
name: "MultiplePageZeroesSHA256SameFile", - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, expectedHash: []byte{247, 158, 42, 215, 180, 106, 0, 28, 77, 64, 132, 162, 74, 65, 250, 161, 243, 66, 129, 44, 197, 8, 145, 14, 94, 206, 156, 184, 145, 145, 20, 185}, }, { name: "MultiplePageZeroesSHA512SeparateFile", - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, expectedHash: []byte{100, 121, 14, 30, 104, 200, 142, 182, 190, 78, 23, 68, 157, 174, 23, 75, 174, 250, 250, 25, 66, 45, 235, 103, 129, 49, 78, 127, 173, 154, 121, 35, 37, 115, 60, 217, 26, 205, 253, 253, 236, 145, 107, 109, 232, 19, 72, 92, 4, 191, 181, 205, 191, 57, 234, 177, 144, 235, 143, 30, 15, 197, 109, 81}, }, { name: "MultiplePageZeroesSHA512SameFile", - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, expectedHash: []byte{100, 121, 14, 30, 104, 200, 142, 182, 190, 78, 23, 68, 157, 174, 23, 75, 174, 250, 250, 25, 66, 45, 235, 103, 129, 49, 78, 127, 173, 154, 121, 35, 37, 115, 60, 217, 26, 205, 253, 253, 236, 145, 107, 109, 232, 19, 72, 92, 4, 191, 181, 205, 191, 57, 234, 177, 144, 235, 143, 30, 15, 197, 109, 81}, @@ -286,28 +287,28 @@ func TestGenerate(t *testing.T) { }, { name: "OnePageASHA256SeparateFile", - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, expectedHash: []byte{132, 54, 112, 142, 156, 19, 50, 140, 138, 240, 192, 154, 100, 120, 242, 69, 64, 217, 62, 166, 127, 88, 23, 197, 100, 66, 255, 215, 214, 229, 54, 1}, }, { name: "OnePageASHA256SameFile", - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, expectedHash: []byte{132, 54, 112, 142, 156, 19, 50, 140, 138, 240, 192, 154, 100, 120, 242, 69, 64, 217, 62, 166, 127, 88, 23, 197, 100, 66, 255, 215, 214, 229, 54, 1}, }, { name: "OnePageASHA512SeparateFile", - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, expectedHash: []byte{165, 46, 176, 116, 47, 209, 101, 193, 64, 185, 30, 9, 52, 22, 24, 154, 135, 220, 232, 168, 215, 45, 222, 226, 207, 104, 160, 10, 156, 98, 245, 250, 76, 21, 68, 204, 65, 118, 69, 52, 210, 155, 36, 109, 233, 103, 1, 40, 218, 89, 125, 38, 247, 194, 2, 225, 119, 155, 65, 99, 182, 111, 110, 145}, }, { name: "OnePageASHA512SameFile", - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, expectedHash: []byte{165, 46, 176, 116, 47, 209, 101, 193, 64, 185, 30, 9, 52, 22, 24, 154, 135, 220, 232, 168, 215, 45, 222, 226, 207, 104, 160, 10, 156, 98, 245, 250, 76, 21, 68, 204, 65, 118, 69, 52, 210, 155, 36, 109, 233, 103, 1, 40, 218, 89, 125, 38, 247, 194, 2, 225, 119, 155, 65, 99, 182, 111, 110, 145}, @@ -415,14 +416,14 @@ func TestVerifyInvalidRange(t *testing.T) { // Verify range starts outside data range. 
{ name: "StartOutsideRange", - verifyStart: usermem.PageSize, + verifyStart: hostarch.PageSize, verifySize: 1, }, // Verify range ends outside data range. { name: "EndOutsideRange", verifyStart: 0, - verifySize: 2 * usermem.PageSize, + verifySize: 2 * hostarch.PageSize, }, // Verify range with negative size. { @@ -434,7 +435,7 @@ func TestVerifyInvalidRange(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, false /* isSymlink */, tc.verifyStart, tc.verifySize, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, false /* isSymlink */, tc.verifyStart, tc.verifySize, &buf) if _, err := Verify(¶ms); errors.Is(err, nil) { t.Errorf("Verification succeeded when expected to fail") } @@ -467,7 +468,7 @@ func TestVerifyUnmodifiedMetadata(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, tc.isSymlink, 0 /* verifyStart */, 0 /* verifySize */, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, tc.isSymlink, 0 /* verifyStart */, 0 /* verifySize */, &buf) if tc.isSymlink { params.SymlinkTarget = defaultSymlinkPath } @@ -495,7 +496,7 @@ func TestVerifyModifiedName(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) params.Name += "abc" if _, err := Verify(¶ms); errors.Is(err, nil) { t.Errorf("Verification succeeded when expected to fail") @@ -521,7 +522,7 @@ func TestVerifyModifiedSize(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) params.Size-- if _, err := Verify(¶ms); errors.Is(err, nil) { t.Errorf("Verification succeeded when expected to fail") @@ -547,7 +548,7 @@ func TestVerifyModifiedMode(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) params.Mode++ if _, err := Verify(¶ms); errors.Is(err, nil) { t.Errorf("Verification succeeded when expected to fail") @@ -573,7 +574,7 @@ func TestVerifyModifiedUID(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) 
{ var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) params.UID++ if _, err := Verify(¶ms); errors.Is(err, nil) { t.Errorf("Verification succeeded when expected to fail") @@ -599,7 +600,7 @@ func TestVerifyModifiedGID(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) params.GID++ if _, err := Verify(¶ms); errors.Is(err, nil) { t.Errorf("Verification succeeded when expected to fail") @@ -625,7 +626,7 @@ func TestVerifyModifiedChildren(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) params.Children["abc"] = struct{}{} if _, err := Verify(¶ms); errors.Is(err, nil) { t.Errorf("Verification succeeded when expected to fail") @@ -636,7 +637,7 @@ func TestVerifyModifiedChildren(t *testing.T) { func TestVerifyModifiedSymlink(t *testing.T) { var buf bytes.Buffer - _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, true /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) + _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, true /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) params.SymlinkTarget = "merkle_modified_test_link" if _, err := Verify(¶ms); err == nil { t.Errorf("Verification succeeded when expected to fail") @@ -652,30 +653,30 @@ func TestModifyOutsideVerifyRange(t *testing.T) { }{ { name: "BeforeRangeSeparateFile", - modifyByte: 4*usermem.PageSize - 1, + modifyByte: 4*hostarch.PageSize - 1, dataAndTreeInSameFile: false, }, { name: "BeforeRangeSameFile", - modifyByte: 4*usermem.PageSize - 1, + modifyByte: 4*hostarch.PageSize - 1, dataAndTreeInSameFile: true, }, { name: "AfterRangeSeparateFile", - modifyByte: 5 * usermem.PageSize, + modifyByte: 5 * hostarch.PageSize, dataAndTreeInSameFile: false, }, { name: "AfterRangeSameFile", - modifyByte: 5 * usermem.PageSize, + modifyByte: 5 * hostarch.PageSize, dataAndTreeInSameFile: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - dataSize := int64(8 * usermem.PageSize) - verifyStart := int64(4 * usermem.PageSize) - verifySize := int64(usermem.PageSize) + dataSize := int64(8 * hostarch.PageSize) + verifyStart := int64(4 * hostarch.PageSize) + verifySize := int64(hostarch.PageSize) var buf bytes.Buffer // Modified byte is outside verify range. Verify should succeed. 
data, params := prepareVerify(t, dataSize, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, verifyStart, verifySize, &buf) @@ -712,16 +713,16 @@ func TestModifyInsideVerifyRange(t *testing.T) { // to fail. { name: "BlockAlignedRangeSeparateFile", - verifyStart: 4 * usermem.PageSize, - verifySize: usermem.PageSize, - modifyByte: 4 * usermem.PageSize, + verifyStart: 4 * hostarch.PageSize, + verifySize: hostarch.PageSize, + modifyByte: 4 * hostarch.PageSize, dataAndTreeInSameFile: false, }, { name: "BlockAlignedRangeSameFile", - verifyStart: 4 * usermem.PageSize, - verifySize: usermem.PageSize, - modifyByte: 4 * usermem.PageSize, + verifyStart: 4 * hostarch.PageSize, + verifySize: hostarch.PageSize, + modifyByte: 4 * hostarch.PageSize, dataAndTreeInSameFile: true, }, // The tests below use a non-block-aligned verify range. @@ -729,48 +730,48 @@ func TestModifyInsideVerifyRange(t *testing.T) { // verify to fail. { name: "ModifyStartSeparateFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 4*usermem.PageSize + 123, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 4*hostarch.PageSize + 123, dataAndTreeInSameFile: false, }, { name: "ModifyStartSameFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 4*usermem.PageSize + 123, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 4*hostarch.PageSize + 123, dataAndTreeInSameFile: true, }, // Modifying a byte at the end of verify range should cause // verify to fail. { name: "ModifyEndSeparateFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 6*usermem.PageSize + 123, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 6*hostarch.PageSize + 123, dataAndTreeInSameFile: false, }, { name: "ModifyEndSameFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 6*usermem.PageSize + 123, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 6*hostarch.PageSize + 123, dataAndTreeInSameFile: true, }, // Modifying a byte in the middle verified block should cause // verify to fail. { name: "ModifyMiddleSeparateFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 5*usermem.PageSize + 123, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 5*hostarch.PageSize + 123, dataAndTreeInSameFile: false, }, { name: "ModifyMiddleSameFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 5*usermem.PageSize + 123, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 5*hostarch.PageSize + 123, dataAndTreeInSameFile: true, }, // Modifying a byte in the first block in the verified range @@ -778,16 +779,16 @@ func TestModifyInsideVerifyRange(t *testing.T) { // out of verify range. 
{ name: "ModifyFirstBlockSeparateFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 4*usermem.PageSize + 122, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 4*hostarch.PageSize + 122, dataAndTreeInSameFile: false, }, { name: "ModifyFirstBlockSameFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 4*usermem.PageSize + 122, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 4*hostarch.PageSize + 122, dataAndTreeInSameFile: true, }, // Modifying a byte in the last block in the verified range @@ -795,22 +796,22 @@ func TestModifyInsideVerifyRange(t *testing.T) { // out of verify range. { name: "ModifyLastBlockSeparateFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 6*usermem.PageSize + 124, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 6*hostarch.PageSize + 124, dataAndTreeInSameFile: false, }, { name: "ModifyLastBlockSameFile", - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 6*usermem.PageSize + 124, + verifyStart: 4*hostarch.PageSize + 123, + verifySize: 2 * hostarch.PageSize, + modifyByte: 6*hostarch.PageSize + 124, dataAndTreeInSameFile: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - dataSize := int64(8 * usermem.PageSize) + dataSize := int64(8 * hostarch.PageSize) var buf bytes.Buffer data, params := prepareVerify(t, dataSize, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, tc.verifyStart, tc.verifySize, &buf) // Flip a bit in data and checks Verify results. @@ -854,7 +855,7 @@ func TestVerifyRandom(t *testing.T) { rand.Seed(time.Now().UnixNano()) // Use a random dataSize. Minimum size 2 so that we can pick a random // portion from it. - dataSize := rand.Int63n(200*usermem.PageSize) + 2 + dataSize := rand.Int63n(200*hostarch.PageSize) + 2 // Pick a random portion of data. start := rand.Int63n(dataSize - 1) diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index c9f9357de..2a2f0d611 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -38,7 +38,7 @@ var ( ) // Uint64Metric encapsulates a uint64 that represents some kind of metric to be -// monitored. +// monitored. We currently support metrics with at most one field. // // Metrics are not saved across save/restore and thus reset to zero on restore. // @@ -46,6 +46,16 @@ var ( type Uint64Metric struct { // value is the actual value of the metric. It must be accessed atomically. value uint64 + + // numFields is the number of metric fields. It is immutable once + // initialized. + numFields int + + // mu protects the below fields. + mu sync.RWMutex `state:"nosave"` + + // fields is the map of fields in the metric. + fields map[string]uint64 } var ( @@ -97,8 +107,19 @@ type customUint64Metric struct { // metadata describes the metric. It is immutable. metadata *pb.MetricMetadata - // value returns the current value of the metric. - value func() uint64 + // value returns the current value of the metric for the given set of + // fields. It takes a variadic number of field values as argument. + value func(fieldValues ...string) uint64 +} + +// Field contains the field name and allowed values for the metric which is +// used in registration of the metric. +type Field struct { + // name is the metric field name. 
+ name string + + // allowedValues is the list of allowed values for the field. + allowedValues []string } // RegisterCustomUint64Metric registers a metric with the given name. @@ -109,7 +130,8 @@ type customUint64Metric struct { // Preconditions: // * name must be globally unique. // * Initialize/Disable have not been called. -func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error { +// * value is expected to accept exactly len(fields) arguments. +func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func(...string) uint64, fields ...Field) error { if initialized { return ErrInitializationDone } @@ -129,13 +151,25 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met }, value: value, } + + // Metrics can exist without fields. + if len(fields) > 1 { + panic("Sentry metrics support at most one field") + } + + for _, field := range fields { + allMetrics.m[name].metadata.Fields = append(allMetrics.m[name].metadata.Fields, &pb.MetricMetadata_Field{ + FieldName: field.name, + AllowedValues: field.allowedValues, + }) + } return nil } -// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric and panics -// if it returns an error. -func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func() uint64) { - if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value); err != nil { +// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric for metrics +// without fields and panics if it returns an error. +func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) { + if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil { panic(fmt.Sprintf("Unable to register metric %q: %v", name, err)) } } @@ -144,15 +178,24 @@ func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, descript // name. // // Metrics must be statically defined (i.e., at init). -func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string) (*Uint64Metric, error) { - var m Uint64Metric - return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value) +func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string, fields ...Field) (*Uint64Metric, error) { + m := Uint64Metric{ + numFields: len(fields), + } + + if m.numFields == 1 { + m.fields = make(map[string]uint64) + for _, fieldValue := range fields[0].allowedValues { + m.fields[fieldValue] = 0 + } + } + return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value, fields...) } // MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an // error. -func MustCreateNewUint64Metric(name string, sync bool, description string) *Uint64Metric { - m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description) +func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric { + m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...) 
if err != nil { panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) } @@ -169,19 +212,56 @@ func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description st return m } -// Value returns the current value of the metric. -func (m *Uint64Metric) Value() uint64 { - return atomic.LoadUint64(&m.value) +// Value returns the current value of the metric for the given set of fields. +func (m *Uint64Metric) Value(fieldValues ...string) uint64 { + if m.numFields != len(fieldValues) { + panic(fmt.Sprintf("Number of fieldValues %d is not equal to the number of metric fields %d", len(fieldValues), m.numFields)) + } + + switch m.numFields { + case 0: + return atomic.LoadUint64(&m.value) + case 1: + m.mu.RLock() + defer m.mu.RUnlock() + + fieldValue := fieldValues[0] + if _, ok := m.fields[fieldValue]; !ok { + panic(fmt.Sprintf("Metric does not allow to have field value %s", fieldValue)) + } + return m.fields[fieldValue] + default: + panic("Sentry metrics do not support more than one field") + } } -// Increment increments the metric by 1. -func (m *Uint64Metric) Increment() { - atomic.AddUint64(&m.value, 1) +// Increment increments the metric field by 1. +func (m *Uint64Metric) Increment(fieldValues ...string) { + m.IncrementBy(1, fieldValues...) } // IncrementBy increments the metric by v. -func (m *Uint64Metric) IncrementBy(v uint64) { - atomic.AddUint64(&m.value, v) +func (m *Uint64Metric) IncrementBy(v uint64, fieldValues ...string) { + if m.numFields != len(fieldValues) { + panic(fmt.Sprintf("Number of fieldValues %d is not equal to the number of metric fields %d", len(fieldValues), m.numFields)) + } + + switch m.numFields { + case 0: + atomic.AddUint64(&m.value, v) + return + case 1: + fieldValue := fieldValues[0] + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.fields[fieldValue]; !ok { + panic(fmt.Sprintf("Metric does not allow to have field value %s", fieldValue)) + } + m.fields[fieldValue] += v + default: + panic("Sentry metrics do not support more than one field") + } } // metricSet holds named metrics. @@ -199,14 +279,30 @@ func makeMetricSet() metricSet { // Values returns a snapshot of all values in m. func (m *metricSet) Values() metricValues { vals := make(metricValues) + for k, v := range m.m { - vals[k] = v.value() + fields := v.metadata.GetFields() + switch len(fields) { + case 0: + vals[k] = v.value() + case 1: + values := fields[0].GetAllowedValues() + fieldsMap := make(map[string]uint64) + for _, fieldValue := range values { + fieldsMap[fieldValue] = v.value(fieldValue) + } + vals[k] = fieldsMap + default: + panic(fmt.Sprintf("Unsupported number of metric fields: %d", len(fields))) + } } return vals } -// metricValues contains a copy of the values of all metrics. -type metricValues map[string]uint64 +// metricValues contains a copy of the values of all metrics. It is a map +// with key as metric name and value can be either uint64 or map[string]uint64 +// to support metrics with one field. +type metricValues map[string]interface{} var ( // emitMu protects metricsAtLastEmit and ensures that all emitted @@ -233,14 +329,37 @@ func EmitMetricUpdate() { snapshot := allMetrics.Values() m := pb.MetricUpdate{} + // On the first call metricsAtLastEmit will be empty. Include all + // metrics then. for k, v := range snapshot { - // On the first call metricsAtLastEmit will be empty. Include - // all metrics then. 
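As an aside on the field-aware metric API introduced above: a minimal usage sketch, written as if inside the metric package (Field's members are unexported) and using a hypothetical metric name; registration must still happen at init time, before Initialize().

weirdness := Field{name: "weirdness_type", allowedValues: []string{"weird1", "weird2"}}
counter := MustCreateNewUint64Metric("/example/weirdness", false /* sync */, "example counter", weirdness)
counter.IncrementBy(4, "weird1") // bumps only the "weird1" bucket.
counter.Increment("weird2")      // shorthand for IncrementBy(1, "weird2").
_ = counter.Value("weird1")      // reads back 4; a field value outside allowedValues panics.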
- if prev, ok := metricsAtLastEmit[k]; !ok || prev != v { + prev, ok := metricsAtLastEmit[k] + switch t := v.(type) { + case uint64: + // Metric exists and value did not change. + if ok && prev.(uint64) == t { + continue + } + m.Metrics = append(m.Metrics, &pb.MetricValue{ Name: k, - Value: &pb.MetricValue_Uint64Value{v}, + Value: &pb.MetricValue_Uint64Value{t}, }) + case map[string]uint64: + for fieldValue, metricValue := range t { + // Emit data on the first call only if the field + // value has been incremented. For all other + // calls, emit data if the field value has been + // changed from the previous emit. + if (!ok && metricValue == 0) || (ok && prev.(map[string]uint64)[fieldValue] == metricValue) { + continue + } + + m.Metrics = append(m.Metrics, &pb.MetricValue{ + Name: k, + FieldValues: []string{fieldValue}, + Value: &pb.MetricValue_Uint64Value{metricValue}, + }) + } } } diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto index 3cc89047d..53c8b4b50 100644 --- a/pkg/metric/metric.proto +++ b/pkg/metric/metric.proto @@ -48,6 +48,15 @@ message MetricMetadata { // units is the units of the metric value. Units units = 6; + + message Field { + string field_name = 1; + repeated string allowed_values = 2; + } + + // fields contains the metric fields. Currently a metric can have at most + // one field. + repeated Field fields = 7; } // MetricRegistration contains the metadata for all metrics that will be in @@ -66,6 +75,8 @@ message MetricValue { oneof value { uint64 uint64_value = 2; } + + repeated string field_values = 4; } // MetricUpdate contains new values for multiple distinct metrics. diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go index aefd0ea5c..c71dfd460 100644 --- a/pkg/metric/metric_test.go +++ b/pkg/metric/metric_test.go @@ -59,8 +59,9 @@ func reset() { } const ( - fooDescription = "Foo!" - barDescription = "Bar Baz" + fooDescription = "Foo!" + barDescription = "Bar Baz" + counterDescription = "Counter" ) func TestInitialize(t *testing.T) { @@ -95,7 +96,7 @@ func TestInitialize(t *testing.T) { foundBar := false for _, m := range mr.Metrics { if m.Type != pb.MetricMetadata_TYPE_UINT64 { - t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_TYPE_UINT64) + t.Errorf("Metadata %+v Type got %v want pb.MetricMetadata_TYPE_UINT64", m, m.Type) } if !m.Cumulative { t.Errorf("Metadata %+v Cumulative got false want true", m) @@ -256,3 +257,88 @@ func TestEmitMetricUpdate(t *testing.T) { t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value) } } + +func TestEmitMetricUpdateWithFields(t *testing.T) { + defer reset() + + field := Field{ + name: "weirdness_type", + allowedValues: []string{"weird1", "weird2"}} + + counter, err := NewUint64Metric("/weirdness", false, pb.MetricMetadata_UNITS_NONE, counterDescription, field) + if err != nil { + t.Fatalf("NewUint64Metric got err %v want nil", err) + } + + Initialize() + + // Don't care about the registration metrics. + emitter.Reset() + EmitMetricUpdate() + + // For metrics with fields, we do not emit data unless the value is + // incremented. 
+ if len(emitter) != 0 { + t.Fatalf("EmitMetricUpdate emitted %d events want 0", len(emitter)) + } + + counter.IncrementBy(4, "weird1") + counter.Increment("weird2") + + emitter.Reset() + EmitMetricUpdate() + + if len(emitter) != 1 { + t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) + } + + update, ok := emitter[0].(*pb.MetricUpdate) + if !ok { + t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) + } + + if len(update.Metrics) != 2 { + t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics)) + } + + foundWeird1 := false + foundWeird2 := false + for i := 0; i < len(update.Metrics); i++ { + m := update.Metrics[i] + + if m.Name != "/weirdness" { + t.Errorf("Metric %+v name got %q want '/weirdness'", m, m.Name) + } + if len(m.FieldValues) != 1 { + t.Errorf("MetricUpdate got %d fields want 1", len(m.FieldValues)) + } + + switch m.FieldValues[0] { + case "weird1": + uv, ok := m.Value.(*pb.MetricValue_Uint64Value) + if !ok { + t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) + } + if uv.Uint64Value != 4 { + t.Errorf("%v: Value got %v want 4", m, uv.Uint64Value) + } + foundWeird1 = true + case "weird2": + uv, ok := m.Value.(*pb.MetricValue_Uint64Value) + if !ok { + t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) + } + if uv.Uint64Value != 1 { + t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value) + } + foundWeird2 = true + } + } + + if !foundWeird1 { + t.Errorf("Field value weird1 not found: %+v", emitter) + } + if !foundWeird2 { + t.Errorf("Field value weird2 not found: %+v", emitter) + } +} diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go index 0472eca3f..fb8984dd6 100644 --- a/pkg/refsvfs2/refs_map.go +++ b/pkg/refsvfs2/refs_map.go @@ -112,20 +112,27 @@ func logEvent(obj CheckedObject, msg string) { log.Infof("[%s %p] %s:\n%s", obj.RefType(), obj, msg, refs_vfs1.FormatStack(refs_vfs1.RecordStack())) } +// checkOnce makes sure that leak checking is only done once. DoLeakCheck is +// called from multiple places (which may overlap) to cover different sandbox +// exit scenarios. +var checkOnce sync.Once + // DoLeakCheck iterates through the live object map and logs a message for each // object. It is called once no reference-counted objects should be reachable // anymore, at which point anything left in the map is considered a leak. 
func DoLeakCheck() { if leakCheckEnabled() { - liveObjectsMu.Lock() - defer liveObjectsMu.Unlock() - leaked := len(liveObjects) - if leaked > 0 { - msg := fmt.Sprintf("Leak checking detected %d leaked objects:\n", leaked) - for obj := range liveObjects { - msg += obj.LeakMessage() + "\n" + checkOnce.Do(func() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + leaked := len(liveObjects) + if leaked > 0 { + msg := fmt.Sprintf("Leak checking detected %d leaked objects:\n", leaked) + for obj := range liveObjects { + msg += obj.LeakMessage() + "\n" + } + log.Warningf(msg) } - log.Warningf(msg) - } + }) } } diff --git a/pkg/ring0/BUILD b/pkg/ring0/BUILD index 885958456..fda6ba601 100644 --- a/pkg/ring0/BUILD +++ b/pkg/ring0/BUILD @@ -77,10 +77,10 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/cpuid", + "//pkg/hostarch", "//pkg/ring0/pagetables", "//pkg/safecopy", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", - "//pkg/usermem", ], ) diff --git a/pkg/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go index ceddf719d..76776c65c 100644 --- a/pkg/ring0/defs_amd64.go +++ b/pkg/ring0/defs_amd64.go @@ -17,7 +17,7 @@ package ring0 import ( - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) var ( @@ -25,7 +25,7 @@ var ( UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(hostarch.PageSize-1) // KernelStartAddress is the starting kernel address. KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) diff --git a/pkg/ring0/defs_arm64.go b/pkg/ring0/defs_arm64.go index c372b02bb..0125690d2 100644 --- a/pkg/ring0/defs_arm64.go +++ b/pkg/ring0/defs_arm64.go @@ -17,7 +17,7 @@ package ring0 import ( - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) var ( @@ -25,7 +25,7 @@ var ( UserspaceSize = uintptr(1) << (VirtualAddressBits()) // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(hostarch.PageSize-1) // KernelStartAddress is the starting kernel address. KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) diff --git a/pkg/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD index f421e1687..9ea8f9a4f 100644 --- a/pkg/ring0/gen_offsets/BUILD +++ b/pkg/ring0/gen_offsets/BUILD @@ -33,9 +33,9 @@ go_binary( ], deps = [ "//pkg/cpuid", + "//pkg/hostarch", "//pkg/ring0/pagetables", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", - "//pkg/usermem", ], ) diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index 33c259757..41dfd0bf9 100644 --- a/pkg/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go @@ -20,7 +20,7 @@ import ( "encoding/binary" "reflect" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // init initializes architecture-specific state. @@ -34,7 +34,7 @@ func (k *Kernel) init(maxCPUs int) { entries = make([]kernelEntry, maxCPUs+padding-1) totalSize := entrySize * uintptr(maxCPUs+padding-1) addr := reflect.ValueOf(&entries[0]).Pointer() - if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize { + if addr&(hostarch.PageSize-1) == 0 && totalSize >= hostarch.PageSize { // The runtime forces power-of-2 alignment for allocations, and we are therefore // safe once the first address is aligned and the chunk is at least a full page. 
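Returning briefly to the refsvfs2 hunk above: the once-only behavior now comes from sync.Once rather than from the exit paths coordinating with each other. A minimal standalone sketch of the same pattern, using the standard library's sync package for illustration:

package main

import (
	"fmt"
	"sync"
)

var checkOnce sync.Once

// doLeakCheck may be reached from several sandbox-exit paths; the body still
// runs at most once.
func doLeakCheck() {
	checkOnce.Do(func() {
		fmt.Println("walking the live-object map exactly once")
	})
}

func main() {
	doLeakCheck()
	doLeakCheck() // no-op: the Once has already fired.
}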
break @@ -44,10 +44,10 @@ func (k *Kernel) init(maxCPUs int) { k.cpuEntries = entries k.globalIDT = &idt64{} - if reflect.TypeOf(idt64{}).Size() != usermem.PageSize { + if reflect.TypeOf(idt64{}).Size() != hostarch.PageSize { panic("Size of globalIDT should be PageSize") } - if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 { + if reflect.ValueOf(k.globalIDT).Pointer()&(hostarch.PageSize-1) != 0 { panic("Allocated globalIDT should be page aligned") } @@ -71,13 +71,13 @@ func (k *Kernel) EntryRegions() map[uintptr]uintptr { addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer() size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries)) - end, _ := usermem.Addr(addr + size).RoundUp() - regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + end, _ := hostarch.Addr(addr + size).RoundUp() + regions[uintptr(hostarch.Addr(addr).RoundDown())] = uintptr(end) addr = reflect.ValueOf(k.globalIDT).Pointer() size = reflect.TypeOf(idt64{}).Size() - end, _ = usermem.Addr(addr + size).RoundUp() - regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + end, _ = hostarch.Addr(addr + size).RoundUp() + regions[uintptr(hostarch.Addr(addr).RoundDown())] = uintptr(end) return regions } @@ -250,6 +250,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { } SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point. WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. + RestoreKernelFPState() // escapes: no. Restore kernel MXCSR. return } @@ -321,3 +322,21 @@ func SetCPUIDFaulting(on bool) bool { func ReadCR2() uintptr { return readCR2() } + +// kernelMXCSR is the value of the mxcsr register in the Sentry. +// +// The MXCSR control configuration is initialized once and never changed. Look +// at src/cmd/compile/abi-internal.md in the golang sources for more details. +var kernelMXCSR uint32 + +// RestoreKernelFPState restores the Sentry floating point state. +// +//go:nosplit +func RestoreKernelFPState() { + // Restore the MXCSR control configuration. + ldmxcsr(&kernelMXCSR) +} + +func init() { + stmxcsr(&kernelMXCSR) +} diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go index 7975e5f92..21db910a2 100644 --- a/pkg/ring0/kernel_arm64.go +++ b/pkg/ring0/kernel_arm64.go @@ -65,7 +65,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer()) if switchOpts.Flush { - FlushTlbByASID(uintptr(switchOpts.UserASID)) + LocalFlushTlbByASID(uintptr(switchOpts.UserASID)) } regs := switchOpts.Registers @@ -89,3 +89,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { return } + +// RestoreKernelFPState restores the Sentry floating point state. +// +//go:nosplit +func RestoreKernelFPState() { +} diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go index 0ec5c3bc5..3e6bb9663 100644 --- a/pkg/ring0/lib_amd64.go +++ b/pkg/ring0/lib_amd64.go @@ -61,6 +61,12 @@ func wrgsbase(addr uintptr) // wrgsmsr writes to the GS_BASE MSR. func wrgsmsr(addr uintptr) +// stmxcsr reads the MXCSR control and status register. +func stmxcsr(addr *uint32) + +// ldmxcsr writes to the MXCSR control and status register. +func ldmxcsr(addr *uint32) + // readCR2 reads the current CR2 value. 
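The EntryRegions change above relies on the hostarch address helpers (RoundDown/RoundUp) that replace usermem throughout this series. A minimal sketch of those helpers, with illustrative values and assuming the usual 4K page size:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

func main() {
	addr := hostarch.Addr(0x1234)
	start := addr.RoundDown() // 0x1000: start of the containing page.
	end, ok := addr.RoundUp() // 0x2000: next page boundary; ok is false only on overflow.
	fmt.Printf("%#x-%#x ok=%t page=%d\n", start, end, ok, hostarch.PageSize)
}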
func readCR2() uintptr diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s index 2fe83568a..70a43e79e 100644 --- a/pkg/ring0/lib_amd64.s +++ b/pkg/ring0/lib_amd64.s @@ -198,3 +198,15 @@ TEXT ·rdmsr(SB),NOSPLIT,$0-16 MOVL AX, ret+8(FP) MOVL DX, ret+12(FP) RET + +// stmxcsr reads the MXCSR control and status register. +TEXT ·stmxcsr(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), SI + STMXCSR (SI) + RET + +// ldmxcsr writes to the MXCSR control and status register. +TEXT ·ldmxcsr(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), SI + LDMXCSR (SI) + RET diff --git a/pkg/ring0/lib_arm64.go b/pkg/ring0/lib_arm64.go index edf24eda3..5eabd4296 100644 --- a/pkg/ring0/lib_arm64.go +++ b/pkg/ring0/lib_arm64.go @@ -31,6 +31,9 @@ func FlushTlbByVA(addr uintptr) // FlushTlbByASID invalidates tlb by ASID/Inner-Shareable. func FlushTlbByASID(asid uintptr) +// LocalFlushTlbByASID invalidates tlb by ASID. +func LocalFlushTlbByASID(asid uintptr) + // FlushTlbAll invalidates all tlb. func FlushTlbAll() @@ -62,9 +65,10 @@ func LoadFloatingPoint(*byte) // SaveFloatingPoint saves floating point state. func SaveFloatingPoint(*byte) +// FPSIMDDisableTrap disables fpsimd. func FPSIMDDisableTrap() -// DisableVFP disables fpsimd. +// FPSIMDEnableTrap enables fpsimd. func FPSIMDEnableTrap() // Init sets function pointers based on architectural features. diff --git a/pkg/ring0/lib_arm64.s b/pkg/ring0/lib_arm64.s index e39b32841..69ebaf519 100644 --- a/pkg/ring0/lib_arm64.s +++ b/pkg/ring0/lib_arm64.s @@ -32,6 +32,14 @@ TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8 DSB $11 // dsb(ish) RET +TEXT ·LocalFlushTlbByASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + LSL $TLBI_ASID_SHIFT, R1, R1 + DSB $10 // dsb(ishst) + WORD $0xd5088741 // tlbi aside1, x1 + DSB $11 // dsb(ish) + RET + TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 DSB $6 // dsb(nshst) WORD $0xd508871f // __tlbi(vmalle1) diff --git a/pkg/ring0/pagetables/BUILD b/pkg/ring0/pagetables/BUILD index 65a978cbb..f855f4d42 100644 --- a/pkg/ring0/pagetables/BUILD +++ b/pkg/ring0/pagetables/BUILD @@ -68,8 +68,8 @@ go_library( "//pkg/sentry/platform/kvm:__subpackages__", ], deps = [ + "//pkg/hostarch", "//pkg/sync", - "//pkg/usermem", ], ) @@ -84,5 +84,5 @@ go_test( ":walker_check_arm64", ], library = ":pagetables", - deps = ["//pkg/usermem"], + deps = ["//pkg/hostarch"], ) diff --git a/pkg/ring0/pagetables/allocator_unsafe.go b/pkg/ring0/pagetables/allocator_unsafe.go index d08bfdeb3..191d0942b 100644 --- a/pkg/ring0/pagetables/allocator_unsafe.go +++ b/pkg/ring0/pagetables/allocator_unsafe.go @@ -17,23 +17,23 @@ package pagetables import ( "unsafe" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // newAlignedPTEs returns a set of aligned PTEs. func newAlignedPTEs() *PTEs { ptes := new(PTEs) - offset := physicalFor(ptes) & (usermem.PageSize - 1) + offset := physicalFor(ptes) & (hostarch.PageSize - 1) if offset == 0 { // Already aligned. return ptes } // Need to force an aligned allocation. 
- unaligned := make([]byte, (2*usermem.PageSize)-1) - offset = uintptr(unsafe.Pointer(&unaligned[0])) & (usermem.PageSize - 1) + unaligned := make([]byte, (2*hostarch.PageSize)-1) + offset = uintptr(unsafe.Pointer(&unaligned[0])) & (hostarch.PageSize - 1) if offset != 0 { - offset = usermem.PageSize - offset + offset = hostarch.PageSize - offset } return (*PTEs)(unsafe.Pointer(&unaligned[offset])) } diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 8c0a6aa82..3f17fba49 100644 --- a/pkg/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -21,7 +21,7 @@ package pagetables import ( - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // PageTables is a set of page tables. @@ -142,7 +142,7 @@ func (*mapVisitor) requiresSplit() bool { return true } // // +checkescape:hard,stack //go:nosplit -func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { +func (p *PageTables) Map(addr hostarch.Addr, length uintptr, opts MapOpts, physical uintptr) bool { if p.readOnlyShared { panic("Should not modify read-only shared pagetables.") } @@ -198,7 +198,7 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { // // +checkescape:hard,stack //go:nosplit -func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { +func (p *PageTables) Unmap(addr hostarch.Addr, length uintptr) bool { if p.readOnlyShared { panic("Should not modify read-only shared pagetables.") } @@ -249,7 +249,7 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { // // +checkescape:hard,stack //go:nosplit -func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool { +func (p *PageTables) IsEmpty(addr hostarch.Addr, length uintptr) bool { w := emptyWalker{ pageTables: p, } @@ -298,9 +298,9 @@ func (*lookupVisitor) requiresSplit() bool { return false } // // +checkescape:hard,stack //go:nosplit -func (p *PageTables) Lookup(addr usermem.Addr, findFirst bool) (virtual usermem.Addr, physical, size uintptr, opts MapOpts) { - mask := uintptr(usermem.PageSize - 1) - addr &^= usermem.Addr(mask) +func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarch.Addr, physical, size uintptr, opts MapOpts) { + mask := uintptr(hostarch.PageSize - 1) + addr &^= hostarch.Addr(mask) w := lookupWalker{ pageTables: p, visitor: lookupVisitor{ @@ -308,12 +308,12 @@ func (p *PageTables) Lookup(addr usermem.Addr, findFirst bool) (virtual usermem. findFirst: findFirst, }, } - end := ^usermem.Addr(0) &^ usermem.Addr(mask) + end := ^hostarch.Addr(0) &^ hostarch.Addr(mask) if !findFirst { end = addr + 1 } w.iterateRange(uintptr(addr), uintptr(end)) - return usermem.Addr(w.visitor.target), w.visitor.physical, w.visitor.size, w.visitor.opts + return hostarch.Addr(w.visitor.target), w.visitor.physical, w.visitor.size, w.visitor.opts } // MarkReadOnlyShared marks the pagetables read-only and can be shared. diff --git a/pkg/ring0/pagetables/pagetables_aarch64.go b/pkg/ring0/pagetables/pagetables_aarch64.go index 163a3aea3..86eb00a4f 100644 --- a/pkg/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/ring0/pagetables/pagetables_aarch64.go @@ -19,7 +19,7 @@ package pagetables import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // archPageTables is architecture-specific data. @@ -85,7 +85,7 @@ const ( // MapOpts are x86 options. type MapOpts struct { // AccessType defines permissions. 
- AccessType usermem.AccessType + AccessType hostarch.AccessType // Global indicates the page is globally accessible. Global bool @@ -120,7 +120,7 @@ func (p *PTE) Opts() MapOpts { v := atomic.LoadUintptr((*uintptr)(p)) return MapOpts{ - AccessType: usermem.AccessType{ + AccessType: hostarch.AccessType{ Read: true, Write: v&readOnly == 0, Execute: v&xn == 0, diff --git a/pkg/ring0/pagetables/pagetables_amd64_test.go b/pkg/ring0/pagetables/pagetables_amd64_test.go index 54e8e554f..a13c616ae 100644 --- a/pkg/ring0/pagetables/pagetables_amd64_test.go +++ b/pkg/ring0/pagetables/pagetables_amd64_test.go @@ -19,19 +19,19 @@ package pagetables import ( "testing" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) func Test2MAnd4K(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a small page and a huge page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*47) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) + pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: hostarch.Read}, pmdSize*47) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, + {0x00007f0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read}}, }) } @@ -39,12 +39,12 @@ func Test1GAnd4K(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a small page and a super page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*47) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) + pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: hostarch.Read}, pudSize*47) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, + {0x00007f0000000000, pudSize, pudSize * 47, MapOpts{AccessType: hostarch.Read}}, }) } @@ -52,12 +52,12 @@ func TestSplit1GPage(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a super page and knock out the middle. - pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*42) - pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize)) + pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: hostarch.Read}, pudSize*42) + pt.Unmap(hostarch.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize)) checkMappings(t, pt, []mapping{ - {0x00007f0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read}}, + {0x00007f0000000000, pteSize, pudSize * 42, MapOpts{AccessType: hostarch.Read}}, + {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: hostarch.Read}}, }) } @@ -65,11 +65,11 @@ func TestSplit2MPage(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a huge page and knock out the middle. 
- pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*42) - pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize)) + pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: hostarch.Read}, pmdSize*42) + pt.Unmap(hostarch.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize)) checkMappings(t, pt, []mapping{ - {0x00007f0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read}}, + {0x00007f0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: hostarch.Read}}, + {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read}}, }) } diff --git a/pkg/ring0/pagetables/pagetables_arm64_test.go b/pkg/ring0/pagetables/pagetables_arm64_test.go index 2f73d424f..2514b9ac5 100644 --- a/pkg/ring0/pagetables/pagetables_arm64_test.go +++ b/pkg/ring0/pagetables/pagetables_arm64_test.go @@ -19,24 +19,24 @@ package pagetables import ( "testing" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) func Test2MAnd4K(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a small page and a huge page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) - pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*47) - pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42) - pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47) + pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: false}, pteSize*42) + pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: false}, pmdSize*47) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, - {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, - {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}}, - {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}}, + {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: true}}, + {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: false}}, + {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: false}}, }) } @@ -44,12 +44,12 @@ func Test1GAnd4K(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a small page and a super page. 
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) - pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*47) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, - {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}}, + {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: hostarch.Read, User: true}}, }) } @@ -57,12 +57,12 @@ func TestSplit1GPage(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a super page and knock out the middle. - pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42) - pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize)) + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*42) + pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize)) checkMappings(t, pt, []mapping{ - {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, - {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: hostarch.Read, User: true}}, + {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}}, }) } @@ -70,11 +70,11 @@ func TestSplit2MPage(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a huge page and knock out the middle. - pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42) - pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize)) + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*42) + pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize)) checkMappings(t, pt, []mapping{ - {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, - {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: hostarch.Read, User: true}}, + {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}}, }) } diff --git a/pkg/ring0/pagetables/pagetables_test.go b/pkg/ring0/pagetables/pagetables_test.go index 772f4fc5e..df93dcb6a 100644 --- a/pkg/ring0/pagetables/pagetables_test.go +++ b/pkg/ring0/pagetables/pagetables_test.go @@ -15,9 +15,8 @@ package pagetables import ( + "gvisor.dev/gvisor/pkg/hostarch" "testing" - - "gvisor.dev/gvisor/pkg/usermem" ) type mapping struct { @@ -90,7 +89,7 @@ func TestUnmap(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map and unmap one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) pt.Unmap(0x400000, pteSize) checkMappings(t, pt, nil) @@ -100,10 +99,10 @@ func TestReadOnly(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map one entry. 
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*42) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.Read}}, }) } @@ -111,10 +110,10 @@ func TestReadWrite(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, }) } @@ -122,12 +121,12 @@ func TestSerialEntries(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map two sequential entries. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x401000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*47) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) + pt.Map(0x401000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*47) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x401000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.ReadWrite}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, + {0x401000, pteSize, pteSize * 47, MapOpts{AccessType: hostarch.ReadWrite}}, }) } @@ -135,11 +134,11 @@ func TestSpanningEntries(t *testing.T) { pt := New(NewRuntimeAllocator()) // Span a pgd with two pages. - pt.Map(0x00007efffffff000, 2*pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42) + pt.Map(0x00007efffffff000, 2*pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*42) checkMappings(t, pt, []mapping{ - {0x00007efffffff000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000, pteSize, pteSize * 43, MapOpts{AccessType: usermem.Read}}, + {0x00007efffffff000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.Read}}, + {0x00007f0000000000, pteSize, pteSize * 43, MapOpts{AccessType: hostarch.Read}}, }) } @@ -147,11 +146,11 @@ func TestSparseEntries(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map two entries in different pgds. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*47) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) + pt.Map(0x00007f0000000000, pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*47) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.Read}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, + {0x00007f0000000000, pteSize, pteSize * 47, MapOpts{AccessType: hostarch.Read}}, }) } diff --git a/pkg/ring0/pagetables/pagetables_x86.go b/pkg/ring0/pagetables/pagetables_x86.go index 32edd2f0a..e43698173 100644 --- a/pkg/ring0/pagetables/pagetables_x86.go +++ b/pkg/ring0/pagetables/pagetables_x86.go @@ -19,7 +19,7 @@ package pagetables import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // archPageTables is architecture-specific data. @@ -63,7 +63,7 @@ const ( // MapOpts are x86 options. 
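The pagetables tests above already exercise Map and Unmap with the converted types; purely for orientation, a minimal in-package sketch that also shows Lookup (addresses and the physical offset are illustrative):

pt := New(NewRuntimeAllocator())
pt.Map(hostarch.Addr(0x400000), hostarch.PageSize, MapOpts{AccessType: hostarch.ReadWrite}, 0x1000*42)
virtual, physical, size, opts := pt.Lookup(hostarch.Addr(0x400000), false /* findFirst */)
_, _, _, _ = virtual, physical, size, opts // inspect the mapping installed above.
pt.Unmap(hostarch.Addr(0x400000), hostarch.PageSize)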
type MapOpts struct { // AccessType defines permissions. - AccessType usermem.AccessType + AccessType hostarch.AccessType // Global indicates the page is globally accessible. Global bool @@ -97,7 +97,7 @@ func (p *PTE) Valid() bool { func (p *PTE) Opts() MapOpts { v := atomic.LoadUintptr((*uintptr)(p)) return MapOpts{ - AccessType: usermem.AccessType{ + AccessType: hostarch.AccessType{ Read: v&present != 0, Write: v&writable != 0, Execute: v&executeDisable == 0, diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s index a0cd78f33..290579e53 100644 --- a/pkg/safecopy/atomic_amd64.s +++ b/pkg/safecopy/atomic_amd64.s @@ -44,6 +44,12 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24 MOVL AX, old+16(FP) RET +// func addrOfSwapUint32() uintptr +TEXT ·addrOfSwapUint32(SB), $0-8 + MOVQ $·swapUint32(SB), AX + MOVQ AX, ret+0(FP) + RET + // handleSwapUint64Fault returns the value stored in DI. Control is transferred // to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal // number stored in DI. @@ -74,6 +80,12 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28 MOVQ AX, old+16(FP) RET +// func addrOfSwapUint64() uintptr +TEXT ·addrOfSwapUint64(SB), $0-8 + MOVQ $·swapUint64(SB), AX + MOVQ AX, ret+0(FP) + RET + // handleCompareAndSwapUint32Fault returns the value stored in DI. Control is // transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the // signal number stored in DI. @@ -107,6 +119,12 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 MOVL AX, prev+16(FP) RET +// func addrOfCompareAndSwapUint32() uintptr +TEXT ·addrOfCompareAndSwapUint32(SB), $0-8 + MOVQ $·compareAndSwapUint32(SB), AX + MOVQ AX, ret+0(FP) + RET + // handleLoadUint32Fault returns the value stored in DI. Control is transferred // to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal // number stored in DI. @@ -134,3 +152,9 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16 MOVL (AX), BX MOVL BX, val+8(FP) RET + +// func addrOfLoadUint32() uintptr +TEXT ·addrOfLoadUint32(SB), $0-8 + MOVQ $·loadUint32(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s index d58ed71f7..55c031a3c 100644 --- a/pkg/safecopy/atomic_arm64.s +++ b/pkg/safecopy/atomic_arm64.s @@ -33,6 +33,12 @@ again: MOVW R2, old+16(FP) RET +// func addrOfSwapUint32() uintptr +TEXT ·addrOfSwapUint32(SB), $0-8 + MOVD $·swapUint32(SB), R0 + MOVD R0, ret+0(FP) + RET + // handleSwapUint64Fault returns the value stored in R1. Control is transferred // to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal // number stored in R1. @@ -62,6 +68,12 @@ again: MOVD R2, old+16(FP) RET +// func addrOfSwapUint64() uintptr +TEXT ·addrOfSwapUint64(SB), $0-8 + MOVD $·swapUint64(SB), R0 + MOVD R0, ret+0(FP) + RET + // handleCompareAndSwapUint32Fault returns the value stored in R1. Control is // transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS, // with the signal number stored in R1. @@ -97,6 +109,12 @@ done: MOVW R3, prev+16(FP) RET +// func addrOfCompareAndSwapUint32() uintptr +TEXT ·addrOfCompareAndSwapUint32(SB), $0-8 + MOVD $·compareAndSwapUint32(SB), R0 + MOVD R0, ret+0(FP) + RET + // handleLoadUint32Fault returns the value stored in DI. Control is transferred // to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal // number stored in DI. 
@@ -124,3 +142,9 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16 LDARW (R0), R1 MOVW R1, val+8(FP) RET + +// func addrOfLoadUint32() uintptr +TEXT ·addrOfLoadUint32(SB), $0-8 + MOVD $·loadUint32(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/safecopy/memclr_amd64.s b/pkg/safecopy/memclr_amd64.s index 64cf32f05..4abaecaff 100644 --- a/pkg/safecopy/memclr_amd64.s +++ b/pkg/safecopy/memclr_amd64.s @@ -145,3 +145,9 @@ _129through256: MOVOU X0, -32(DI)(BX*1) MOVOU X0, -16(DI)(BX*1) RET + +// func addrOfMemclr() uintptr +TEXT ·addrOfMemclr(SB), $0-8 + MOVQ $·memclr(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/safecopy/memclr_arm64.s b/pkg/safecopy/memclr_arm64.s index 7361b9067..c789bfeb3 100644 --- a/pkg/safecopy/memclr_arm64.s +++ b/pkg/safecopy/memclr_arm64.s @@ -72,3 +72,9 @@ head_loop: CMP $16, R1 BLT tail_zero B aligned_to_16 + +// func addrOfMemclr() uintptr +TEXT ·addrOfMemclr(SB), $0-8 + MOVD $·memclr(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s index 00b46c18f..1d63ca1fd 100644 --- a/pkg/safecopy/memcpy_amd64.s +++ b/pkg/safecopy/memcpy_amd64.s @@ -217,3 +217,9 @@ move_129through256: MOVOU -16(SI)(BX*1), X15 MOVOU X15, -16(DI)(BX*1) RET + +// func addrOfMemcpy() uintptr +TEXT ·addrOfMemcpy(SB), $0-8 + MOVQ $·memcpy(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s index e7e541565..7b3f50aa5 100644 --- a/pkg/safecopy/memcpy_arm64.s +++ b/pkg/safecopy/memcpy_arm64.s @@ -76,3 +76,9 @@ forwardtailloop: CMP R3, R9 BNE forwardtailloop RET + +// func addrOfMemcpy() uintptr +TEXT ·addrOfMemcpy(SB), $0-8 + MOVD $·memcpy(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go index 1e0af5889..df63dd5f1 100644 --- a/pkg/safecopy/safecopy.go +++ b/pkg/safecopy/safecopy.go @@ -18,7 +18,6 @@ package safecopy import ( "fmt" - "reflect" "runtime" "golang.org/x/sys/unix" @@ -91,6 +90,11 @@ var ( // signals. func signalHandler() +// addrOfSignalHandler returns the start address of signalHandler. +// +// See comment on addrOfMemcpy for more details. +func addrOfSignalHandler() uintptr + // FindEndAddress returns the end address (one byte beyond the last) of the // function that contains the specified address (begin). func FindEndAddress(begin uintptr) uintptr { @@ -111,26 +115,26 @@ func initializeAddresses() { // The following functions are written in assembly language, so they won't // be inlined by the existing compiler/linker. Tests will fail if this // assumption is violated. 
- memcpyBegin = reflect.ValueOf(memcpy).Pointer() + memcpyBegin = addrOfMemcpy() memcpyEnd = FindEndAddress(memcpyBegin) - memclrBegin = reflect.ValueOf(memclr).Pointer() + memclrBegin = addrOfMemclr() memclrEnd = FindEndAddress(memclrBegin) - swapUint32Begin = reflect.ValueOf(swapUint32).Pointer() + swapUint32Begin = addrOfSwapUint32() swapUint32End = FindEndAddress(swapUint32Begin) - swapUint64Begin = reflect.ValueOf(swapUint64).Pointer() + swapUint64Begin = addrOfSwapUint64() swapUint64End = FindEndAddress(swapUint64Begin) - compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer() + compareAndSwapUint32Begin = addrOfCompareAndSwapUint32() compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin) - loadUint32Begin = reflect.ValueOf(loadUint32).Pointer() + loadUint32Begin = addrOfLoadUint32() loadUint32End = FindEndAddress(loadUint32Begin) } func init() { initializeAddresses() - if err := ReplaceSignalHandler(unix.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil { + if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) } - if err := ReplaceSignalHandler(unix.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil { + if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) } syserror.AddErrorUnwrapper(func(e error) (unix.Errno, bool) { diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go index d2ce8ff86..55743e69c 100644 --- a/pkg/safecopy/safecopy_test.go +++ b/pkg/safecopy/safecopy_test.go @@ -19,15 +19,13 @@ import ( "fmt" "io/ioutil" "math/rand" - "os" - "runtime/debug" "testing" "unsafe" "golang.org/x/sys/unix" ) -// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular +// Size of a page in bytes. Cloned from hostarch.PageSize to avoid a circular // dependency. const pageSize = 4096 @@ -568,63 +566,3 @@ func TestCompareAndSwapUint32BusError(t *testing.T) { } }) } - -func testCopy(dst, src []byte) (panicked bool) { - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() - debug.SetPanicOnFault(true) - copy(dst, src) - return -} - -func TestSegVOnMemmove(t *testing.T) { - // Test that SIGSEGVs received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer unix.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} - -func TestSigbusOnMemmove(t *testing.T) { - // Test that SIGBUS received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. 
- const bufLen = pageSize - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - os.Remove(f.Name()) - defer f.Close() - - a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer unix.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index a075cf88e..efbc2ddc1 100644 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go @@ -89,6 +89,18 @@ func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig //go:noescape func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) +// Return the start address of the functions above. +// +// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal +// wrapper function rather than the function itself. We must reference from +// assembly to get the ABI0 (i.e., primary) address. +func addrOfMemcpy() uintptr +func addrOfMemclr() uintptr +func addrOfSwapUint32() uintptr +func addrOfSwapUint64() uintptr +func addrOfCompareAndSwapUint32() uintptr +func addrOfLoadUint32() uintptr + // CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes // copied and an error if SIGSEGV or SIGBUS is received while reading from src. func CopyIn(dst []byte, src unsafe.Pointer) (int, error) { diff --git a/pkg/safecopy/sighandler_amd64.s b/pkg/safecopy/sighandler_amd64.s index 475ae48e9..0b5e8df66 100644 --- a/pkg/safecopy/sighandler_amd64.s +++ b/pkg/safecopy/sighandler_amd64.s @@ -131,3 +131,9 @@ handle_fault: MOVL DI, REG_RDI(DX) RET + +// func addrOfSignalHandler() uintptr +TEXT ·addrOfSignalHandler(SB), $0-8 + MOVQ $·signalHandler(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/safecopy/sighandler_arm64.s b/pkg/safecopy/sighandler_arm64.s index 53e4ac2c1..41ed70ff9 100644 --- a/pkg/safecopy/sighandler_arm64.s +++ b/pkg/safecopy/sighandler_arm64.s @@ -141,3 +141,9 @@ handle_fault: MOVW R0, REG_R1(R2) RET + +// func addrOfSignalHandler() uintptr +TEXT ·addrOfSignalHandler(SB), $0-8 + MOVD $·signalHandler(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD index 3fda3a9cc..2c7cc8769 100644 --- a/pkg/safemem/BUILD +++ b/pkg/safemem/BUILD @@ -14,6 +14,7 @@ go_library( deps = [ "//pkg/gohacks", "//pkg/safecopy", + "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go index 93879bb4f..4af534385 100644 --- a/pkg/safemem/block_unsafe.go +++ b/pkg/safemem/block_unsafe.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/sync" ) // A Block is a range of contiguous bytes, similar to []byte but with the @@ -223,8 +224,22 @@ func Copy(dst, src Block) (int, error) { func Zero(dst Block) (int, error) { if !dst.needSafecopy { bs := dst.ToSlice() - for i := range bs { - bs[i] = 0 + if !sync.RaceEnabled { + // If the race detector isn't enabled, the golang + // compiler replaces the next loop with memclr + // (https://github.com/golang/go/issues/5373). 
+ for i := range bs { + bs[i] = 0 + } + } else { + bsLen := len(bs) + if bsLen == 0 { + return 0, nil + } + bs[0] = 0 + for i := 1; i < bsLen; i *= 2 { + copy(bs[i:], bs[:i]) + } } return len(bs), nil } diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index 201dd072f..169fc0ac3 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -54,6 +54,6 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/bpf", - "//pkg/usermem", + "//pkg/hostarch", ], ) diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index db06d1f1b..68feddf31 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -29,7 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // newVictim makes a victim binary. @@ -57,7 +57,7 @@ func dataAsInput(d *linux.SeccompData) bpf.Input { d.MarshalUnsafe(buf) return bpf.InputBytes{ Data: buf, - Order: usermem.ByteOrder, + Order: hostarch.ByteOrder, } } diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index f660f1614..c9c52530d 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/cpuid", + "//pkg/hostarch", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 921151137..290863ee6 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -22,11 +22,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/usermem" ) // Arch describes an architecture. @@ -188,11 +188,11 @@ type Context interface { // returned layout must be no lower than min, and MaxAddr for the returned // layout must be no higher than max. Repeated calls to NewMmapLayout may // return different layouts. - NewMmapLayout(min, max usermem.Addr, limits *limits.LimitSet) (MmapLayout, error) + NewMmapLayout(min, max hostarch.Addr, limits *limits.LimitSet) (MmapLayout, error) // PIELoadAddress returns a preferred load address for a // position-independent executable within l. - PIELoadAddress(l MmapLayout) usermem.Addr + PIELoadAddress(l MmapLayout) hostarch.Addr // FeatureSet returns the FeatureSet in use in this context. FeatureSet() *cpuid.FeatureSet @@ -257,18 +257,18 @@ const ( // +stateify savable type MmapLayout struct { // MinAddr is the lowest mappable address. - MinAddr usermem.Addr + MinAddr hostarch.Addr // MaxAddr is the highest mappable address. - MaxAddr usermem.Addr + MaxAddr hostarch.Addr // BottomUpBase is the lowest address that may be returned for a // MmapBottomUp mmap. - BottomUpBase usermem.Addr + BottomUpBase hostarch.Addr // TopDownBase is the highest address that may be returned for a // MmapTopDown mmap. - TopDownBase usermem.Addr + TopDownBase hostarch.Addr // DefaultDirection is the direction for most non-fixed mmaps in this // layout. @@ -316,9 +316,9 @@ type SyscallArgument struct { // SyscallArguments represents the set of arguments passed to a syscall. type SyscallArguments [6]SyscallArgument -// Pointer returns the usermem.Addr representation of a pointer argument. -func (a SyscallArgument) Pointer() usermem.Addr { - return usermem.Addr(a.Value) +// Pointer returns the hostarch.Addr representation of a pointer argument. 
+func (a SyscallArgument) Pointer() hostarch.Addr { + return hostarch.Addr(a.Value) } // Int returns the int32 representation of a 32-bit signed integer argument. diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 2571be60f..d6b4d2357 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -23,11 +23,11 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/usermem" ) // Host specifies the host architecture. @@ -37,7 +37,7 @@ const Host = AMD64 const ( // maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux // for a 64-bit process. - maxAddr64 usermem.Addr = (1 << 47) - usermem.PageSize + maxAddr64 hostarch.Addr = (1 << 47) - hostarch.PageSize // maxStackRand64 is the maximum randomization to apply to the stack. // It is defined by arch/x86/mm/mmap.c:stack_maxrandom_size in Linux. @@ -45,7 +45,7 @@ const ( // maxMmapRand64 is the maximum randomization to apply to the mmap // layout. It is defined by arch/x86/mm/mmap.c:arch_mmap_rnd in Linux. - maxMmapRand64 = (1 << 28) * usermem.PageSize + maxMmapRand64 = (1 << 28) * hostarch.PageSize // minGap64 is the minimum gap to leave at the top of the address space // for the stack. It is defined by arch/x86/mm/mmap.c:MIN_GAP in Linux. @@ -56,7 +56,7 @@ const ( // // The Platform {Min,Max}UserAddress() may preclude loading at this // address. See other preferredFoo comments below. - preferredPIELoadAddr usermem.Addr = maxAddr64 / 3 * 2 + preferredPIELoadAddr hostarch.Addr = maxAddr64 / 3 * 2 ) // These constants are selected as heuristics to help make the Platform's @@ -92,13 +92,13 @@ const ( // This is all "preferred" because the layout min/max address may not // allow us to select such a TopDownBase, in which case we have to fall // back to a layout that TSAN may not be happy with. - preferredTopDownAllocMin usermem.Addr = 0x7e8000000000 - preferredAllocationGap = 128 << 30 // 128 GB - preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap + preferredTopDownAllocMin hostarch.Addr = 0x7e8000000000 + preferredAllocationGap = 128 << 30 // 128 GB + preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap // minMmapRand64 is the smallest we are willing to make the // randomization to stay above preferredTopDownBaseMin. - minMmapRand64 = (1 << 26) * usermem.PageSize + minMmapRand64 = (1 << 26) * hostarch.PageSize ) // context64 represents an AMD64 context. @@ -207,12 +207,12 @@ func (c *context64) FeatureSet() *cpuid.FeatureSet { } // mmapRand returns a random adjustment for randomizing an mmap layout. -func mmapRand(max uint64) usermem.Addr { - return usermem.Addr(rand.Int63n(int64(max))).RoundDown() +func mmapRand(max uint64) hostarch.Addr { + return hostarch.Addr(rand.Int63n(int64(max))).RoundDown() } // NewMmapLayout implements Context.NewMmapLayout consistently with Linux. -func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) { +func (c *context64) NewMmapLayout(min, max hostarch.Addr, r *limits.LimitSet) (MmapLayout, error) { min, ok := min.RoundUp() if !ok { return MmapLayout{}, unix.EINVAL @@ -230,7 +230,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm // MAX_GAP in Linux. 
maxGap := (max / 6) * 5 - gap := usermem.Addr(stackSize.Cur) + gap := hostarch.Addr(stackSize.Cur) if gap < minGap64 { gap = minGap64 } @@ -243,7 +243,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm } topDownMin := max - gap - maxMmapRand64 - maxRand := usermem.Addr(maxMmapRand64) + maxRand := hostarch.Addr(maxMmapRand64) if topDownMin < preferredTopDownBaseMin { // Try to keep TopDownBase above preferredTopDownBaseMin by // shrinking maxRand. @@ -278,7 +278,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm } // PIELoadAddress implements Context.PIELoadAddress. -func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr { +func (c *context64) PIELoadAddress(l MmapLayout) hostarch.Addr { base := preferredPIELoadAddr max, ok := base.AddLength(maxMmapRand64) if !ok { @@ -311,7 +311,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) - return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil + return c.Native(uintptr(hostarch.ByteOrder.Uint64(buf[addr:]))), nil } // Note: x86 debug registers are missing. return c.Native(0), nil @@ -326,7 +326,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) - usermem.ByteOrder.PutUint64(buf[addr:], uint64(data)) + hostarch.ByteOrder.PutUint64(buf[addr:], uint64(data)) _, err := c.PtraceSetRegs(bytes.NewBuffer(buf)) return err } diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index 14ad9483b..348f238fd 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -22,11 +22,11 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/usermem" ) // Host specifies the host architecture. @@ -36,7 +36,7 @@ const Host = ARM64 const ( // maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux // for a 64-bit process. - maxAddr64 usermem.Addr = (1 << 48) + maxAddr64 hostarch.Addr = (1 << 48) // maxStackRand64 is the maximum randomization to apply to the stack. // It is defined by arch/arm64/mm/mmap.c:(STACK_RND_MASK << PAGE_SHIFT) in Linux. @@ -44,7 +44,7 @@ const ( // maxMmapRand64 is the maximum randomization to apply to the mmap // layout. It is defined by arch/arm64/mm/mmap.c:arch_mmap_rnd in Linux. - maxMmapRand64 = (1 << 33) * usermem.PageSize + maxMmapRand64 = (1 << 33) * hostarch.PageSize // minGap64 is the minimum gap to leave at the top of the address space // for the stack. It is defined by arch/arm64/mm/mmap.c:MIN_GAP in Linux. @@ -55,7 +55,7 @@ const ( // // The Platform {Min,Max}UserAddress() may preclude loading at this // address. See other preferredFoo comments below. - preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5 + preferredPIELoadAddr hostarch.Addr = maxAddr64 / 6 * 5 ) var ( @@ -66,13 +66,13 @@ var ( // These constants are selected as heuristics to help make the Platform's // potentially limited address space conform as closely to Linux as possible. 
const ( - preferredTopDownAllocMin usermem.Addr = 0x7e8000000000 - preferredAllocationGap = 128 << 30 // 128 GB - preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap + preferredTopDownAllocMin hostarch.Addr = 0x7e8000000000 + preferredAllocationGap = 128 << 30 // 128 GB + preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap // minMmapRand64 is the smallest we are willing to make the // randomization to stay above preferredTopDownBaseMin. - minMmapRand64 = (1 << 18) * usermem.PageSize + minMmapRand64 = (1 << 18) * hostarch.PageSize ) // context64 represents an ARM64 context. @@ -187,12 +187,12 @@ func (c *context64) FeatureSet() *cpuid.FeatureSet { } // mmapRand returns a random adjustment for randomizing an mmap layout. -func mmapRand(max uint64) usermem.Addr { - return usermem.Addr(rand.Int63n(int64(max))).RoundDown() +func mmapRand(max uint64) hostarch.Addr { + return hostarch.Addr(rand.Int63n(int64(max))).RoundDown() } // NewMmapLayout implements Context.NewMmapLayout consistently with Linux. -func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) { +func (c *context64) NewMmapLayout(min, max hostarch.Addr, r *limits.LimitSet) (MmapLayout, error) { min, ok := min.RoundUp() if !ok { return MmapLayout{}, unix.EINVAL @@ -210,7 +210,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm // MAX_GAP in Linux. maxGap := (max / 6) * 5 - gap := usermem.Addr(stackSize.Cur) + gap := hostarch.Addr(stackSize.Cur) if gap < minGap64 { gap = minGap64 } @@ -223,7 +223,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm } topDownMin := max - gap - maxMmapRand64 - maxRand := usermem.Addr(maxMmapRand64) + maxRand := hostarch.Addr(maxMmapRand64) if topDownMin < preferredTopDownBaseMin { // Try to keep TopDownBase above preferredTopDownBaseMin by // shrinking maxRand. @@ -258,7 +258,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm } // PIELoadAddress implements Context.PIELoadAddress. -func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr { +func (c *context64) PIELoadAddress(l MmapLayout) hostarch.Addr { base := preferredPIELoadAddr max, ok := base.AddLength(maxMmapRand64) if !ok { diff --git a/pkg/sentry/arch/auxv.go b/pkg/sentry/arch/auxv.go index 2b4c8f3fc..19ca18121 100644 --- a/pkg/sentry/arch/auxv.go +++ b/pkg/sentry/arch/auxv.go @@ -15,7 +15,7 @@ package arch import ( - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // An AuxEntry represents an entry in an ELF auxiliary vector. @@ -23,7 +23,7 @@ import ( // +stateify savable type AuxEntry struct { Key uint64 - Value usermem.Addr + Value hostarch.Addr } // An Auxv represents an ELF auxiliary vector. 
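Editor's aside (not part of the change): auxv.go above now stores hostarch.Addr values in each AuxEntry. A minimal, self-contained sketch of how such a vector might be assembled with the new type; minimalAuxv and its entry parameter are hypothetical, and the AT_* constants are assumed to be the ones defined in pkg/abi/linux:

package example

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/arch"
)

// minimalAuxv builds a two-entry ELF auxiliary vector using the new
// hostarch.Addr value type; entry is a hypothetical program entry point.
func minimalAuxv(entry hostarch.Addr) arch.Auxv {
	return arch.Auxv{
		{Key: linux.AT_PAGESZ, Value: hostarch.Addr(hostarch.PageSize)},
		{Key: linux.AT_ENTRY, Value: entry},
	}
}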
diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD index 0a5395267..4e4f20639 100644 --- a/pkg/sentry/arch/fpu/BUILD +++ b/pkg/sentry/arch/fpu/BUILD @@ -13,9 +13,9 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/cpuid", + "//pkg/hostarch", "//pkg/sync", "//pkg/syserror", - "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go index 3a62f51be..f0ba26736 100644 --- a/pkg/sentry/arch/fpu/fpu_amd64.go +++ b/pkg/sentry/arch/fpu/fpu_amd64.go @@ -21,9 +21,9 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // initX86FPState (defined in asm files) sets up initial state. @@ -146,11 +146,11 @@ const ( // any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section // 10.5.1.2 "SSE State") func sanitizeMXCSR(f State) { - mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:]) + mxcsr := hostarch.ByteOrder.Uint32(f[mxcsrOffset:]) initMXCSRMask.Do(func() { temp := State(alignedBytes(uint(ptraceFPRegsSize), 16)) initX86FPState(&temp[0], false /* useXsave */) - mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:]) + mxcsrMask = hostarch.ByteOrder.Uint32(temp[mxcsrMaskOffset:]) if mxcsrMask == 0 { // "If the value of the MXCSR_MASK field is 00000000H, then the // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM @@ -160,7 +160,7 @@ func sanitizeMXCSR(f State) { } }) mxcsr &= mxcsrMask - usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr) + hostarch.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr) } // PtraceGetXstateRegs implements ptrace(PTRACE_GETREGS, NT_X86_XSTATE) by @@ -177,7 +177,7 @@ func (s *State) PtraceGetXstateRegs(dst io.Writer, maxlen int, featureSet *cpuid // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE // mask. GDB relies on this: see // gdb/x86-linux-nat.c:x86_linux_read_description(). - usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], featureSet.ValidXCR0Mask()) + hostarch.ByteOrder.PutUint64(f[userXstateXCR0Offset:], featureSet.ValidXCR0Mask()) if len(f) > maxlen { f = f[:maxlen] } @@ -208,9 +208,9 @@ func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid // Force reserved bits in MXCSR to 0. This is consistent with Linux. sanitizeMXCSR(State(f)) // Users can't enable *more* XCR0 bits than what we, and the CPU, support. - xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:]) + xstateBV := hostarch.ByteOrder.Uint64(f[xstateBVOffset:]) xstateBV &= featureSet.ValidXCR0Mask() - usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV) + hostarch.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV) // Force XCOMP_BV and reserved bytes in the XSAVE header to 0. reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes] for i := range reserved { @@ -219,6 +219,11 @@ func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid return copy(*s, f), nil } +// SetMXCSR sets the MXCSR control/status register in the state. +func (s *State) SetMXCSR(mxcsr uint32) { + hostarch.ByteOrder.PutUint32((*s)[mxcsrOffset:], mxcsr) +} + // BytePointer returns a pointer to the first byte of the state. // //go:nosplit @@ -266,7 +271,7 @@ func (s *State) AfterLoad() { // What was in use? 
savedBV := fxsaveBV if len(old) >= xstateBVOffset+8 { - savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:]) + savedBV = hostarch.ByteOrder.Uint64(old[xstateBVOffset:]) } // Supported features must be a superset of saved features. diff --git a/pkg/sentry/arch/fpu/fpu_arm64.go b/pkg/sentry/arch/fpu/fpu_arm64.go index d2f62631d..46634661f 100644 --- a/pkg/sentry/arch/fpu/fpu_arm64.go +++ b/pkg/sentry/arch/fpu/fpu_arm64.go @@ -58,6 +58,8 @@ func (s *State) Fork() State { } // BytePointer returns a pointer to the first byte of the state. +// +//go:nosplit func (s *State) BytePointer() *byte { return &(*s)[0] } diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index 35d2e07c3..67d7edf68 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -16,7 +16,7 @@ package arch import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // SignalAct represents the action that should be taken when a signal is @@ -154,107 +154,107 @@ func (s *SignalInfo) FixSignalCodeForUser() { // PID returns the si_pid field. func (s *SignalInfo) PID() int32 { - return int32(usermem.ByteOrder.Uint32(s.Fields[0:4])) + return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4])) } // SetPID mutates the si_pid field. func (s *SignalInfo) SetPID(val int32) { - usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) + hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) } // UID returns the si_uid field. func (s *SignalInfo) UID() int32 { - return int32(usermem.ByteOrder.Uint32(s.Fields[4:8])) + return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) } // SetUID mutates the si_uid field. func (s *SignalInfo) SetUID(val int32) { - usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) + hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) } // Sigval returns the sigval field, which is aliased to both si_int and si_ptr. func (s *SignalInfo) Sigval() uint64 { - return usermem.ByteOrder.Uint64(s.Fields[8:16]) + return hostarch.ByteOrder.Uint64(s.Fields[8:16]) } // SetSigval mutates the sigval field. func (s *SignalInfo) SetSigval(val uint64) { - usermem.ByteOrder.PutUint64(s.Fields[8:16], val) + hostarch.ByteOrder.PutUint64(s.Fields[8:16], val) } // TimerID returns the si_timerid field. func (s *SignalInfo) TimerID() linux.TimerID { - return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4])) + return linux.TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4])) } // SetTimerID sets the si_timerid field. func (s *SignalInfo) SetTimerID(val linux.TimerID) { - usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) + hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) } // Overrun returns the si_overrun field. func (s *SignalInfo) Overrun() int32 { - return int32(usermem.ByteOrder.Uint32(s.Fields[4:8])) + return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) } // SetOverrun sets the si_overrun field. func (s *SignalInfo) SetOverrun(val int32) { - usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) + hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) } // Addr returns the si_addr field. func (s *SignalInfo) Addr() uint64 { - return usermem.ByteOrder.Uint64(s.Fields[0:8]) + return hostarch.ByteOrder.Uint64(s.Fields[0:8]) } // SetAddr sets the si_addr field. func (s *SignalInfo) SetAddr(val uint64) { - usermem.ByteOrder.PutUint64(s.Fields[0:8], val) + hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) } // Status returns the si_status field. 
func (s *SignalInfo) Status() int32 { - return int32(usermem.ByteOrder.Uint32(s.Fields[8:12])) + return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) } // SetStatus mutates the si_status field. func (s *SignalInfo) SetStatus(val int32) { - usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) + hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) } // CallAddr returns the si_call_addr field. func (s *SignalInfo) CallAddr() uint64 { - return usermem.ByteOrder.Uint64(s.Fields[0:8]) + return hostarch.ByteOrder.Uint64(s.Fields[0:8]) } // SetCallAddr mutates the si_call_addr field. func (s *SignalInfo) SetCallAddr(val uint64) { - usermem.ByteOrder.PutUint64(s.Fields[0:8], val) + hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) } // Syscall returns the si_syscall field. func (s *SignalInfo) Syscall() int32 { - return int32(usermem.ByteOrder.Uint32(s.Fields[8:12])) + return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) } // SetSyscall mutates the si_syscall field. func (s *SignalInfo) SetSyscall(val int32) { - usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) + hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) } // Arch returns the si_arch field. func (s *SignalInfo) Arch() uint32 { - return usermem.ByteOrder.Uint32(s.Fields[12:16]) + return hostarch.ByteOrder.Uint32(s.Fields[12:16]) } // SetArch mutates the si_arch field. func (s *SignalInfo) SetArch(val uint32) { - usermem.ByteOrder.PutUint32(s.Fields[12:16], val) + hostarch.ByteOrder.PutUint32(s.Fields[12:16], val) } // Band returns the si_band field. func (s *SignalInfo) Band() int64 { - return int64(usermem.ByteOrder.Uint64(s.Fields[0:8])) + return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8])) } // SetBand mutates the si_band field. @@ -262,15 +262,15 @@ func (s *SignalInfo) SetBand(val int64) { // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is // `int`. See siginfo.h. - usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) + hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) } // FD returns the si_fd field. func (s *SignalInfo) FD() uint32 { - return usermem.ByteOrder.Uint32(s.Fields[8:12]) + return hostarch.ByteOrder.Uint32(s.Fields[8:12]) } // SetFD mutates the si_fd field. func (s *SignalInfo) SetFD(val uint32) { - usermem.ByteOrder.PutUint32(s.Fields[8:12], val) + hostarch.ByteOrder.PutUint32(s.Fields[8:12], val) } diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index ee3743483..082ed92b1 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -21,10 +21,10 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" - "gvisor.dev/gvisor/pkg/usermem" ) // SignalContext64 is equivalent to struct sigcontext, the type passed as the @@ -133,7 +133,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // space on the user stack naturally caps the amount of memory the // sentry will allocate for this purpose. fpSize, _ := c.fpuFrameSize() - sp = (sp - usermem.Addr(fpSize)) & ^usermem.Addr(63) + sp = (sp - hostarch.Addr(fpSize)) & ^hostarch.Addr(63) // Construct the UContext64 now since we need its size. 
uc := &UContext64{ @@ -180,8 +180,8 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt ucSize := uc.SizeBytes() // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128. frameSize := int(st.Arch.Width()) + ucSize + 128 - frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8 - sp = frameBottom + usermem.Addr(frameSize) + frameBottom := (sp-hostarch.Addr(frameSize)) & ^hostarch.Addr(15) - 8 + sp = frameBottom + hostarch.Addr(frameSize) st.Bottom = sp // Prior to proceeding, figure out if the frame will exhaust the range diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 53281dcba..da71fb873 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -19,9 +19,9 @@ package arch import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" - "gvisor.dev/gvisor/pkg/usermem" ) // SignalContext64 is equivalent to struct sigcontext, the type passed as the @@ -107,8 +107,8 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // sizeof(siginfo) == 128. // R30 stores the restorer address. frameSize := ucSize + 128 - frameBottom := (sp - usermem.Addr(frameSize)) & ^usermem.Addr(15) - sp = frameBottom + usermem.Addr(frameSize) + frameBottom := (sp - hostarch.Addr(frameSize)) & ^hostarch.Addr(15) + sp = frameBottom + hostarch.Addr(frameSize) st.Bottom = sp // Prior to proceeding, figure out if the frame will exhaust the range diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go index a1eae98f9..c732c7503 100644 --- a/pkg/sentry/arch/signal_stack.go +++ b/pkg/sentry/arch/signal_stack.go @@ -17,8 +17,8 @@ package arch import ( + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" - "gvisor.dev/gvisor/pkg/usermem" ) const ( @@ -36,8 +36,8 @@ func (s SignalStack) IsEnabled() bool { } // Top returns the stack's top address. -func (s SignalStack) Top() usermem.Addr { - return usermem.Addr(s.Addr + s.Size) +func (s SignalStack) Top() hostarch.Addr { + return hostarch.Addr(s.Addr + s.Size) } // SetOnStack marks this signal stack as in use. @@ -49,8 +49,8 @@ func (s *SignalStack) SetOnStack() { } // Contains checks if the stack pointer is within this stack. -func (s *SignalStack) Contains(sp usermem.Addr) bool { - return usermem.Addr(s.Addr) < sp && sp <= usermem.Addr(s.Addr+s.Size) +func (s *SignalStack) Contains(sp hostarch.Addr) bool { + return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size) } // NativeSignalStack is a type that is equivalent to stack_t in the guest diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 5f06c751d..65a794c7c 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -16,18 +16,20 @@ package arch import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/usermem" ) -// Stack is a simple wrapper around a usermem.IO and an address. Stack +// Stack is a simple wrapper around a hostarch.IO and an address. Stack // implements marshal.CopyContext, and marshallable values can be pushed or // popped from the stack through the marshal.Marshallable interface. // // Stack is not thread-safe. type Stack struct { // Our arch info. 
- // We use this for automatic Native conversion of usermem.Addrs during + // We use this for automatic Native conversion of hostarch.Addrs during // Push() and Pop(). Arch Context @@ -35,7 +37,7 @@ type Stack struct { IO usermem.IO // Our current stack bottom. - Bottom usermem.Addr + Bottom hostarch.Addr // Scratch buffer used for marshalling to avoid having to repeatedly // allocate scratch memory. @@ -59,20 +61,20 @@ func (s *Stack) CopyScratchBuffer(size int) []byte { // StackBottomMagic is the special address callers must past to all stack // marshalling operations to cause the src/dst address to be computed based on // the current end of the stack. -const StackBottomMagic = ^usermem.Addr(0) // usermem.Addr(-1) +const StackBottomMagic = ^hostarch.Addr(0) // hostarch.Addr(-1) // CopyOutBytes implements marshal.CopyContext.CopyOutBytes. CopyOutBytes // computes an appropriate address based on the current end of the // stack. Callers use the sentinel address StackBottomMagic to marshal methods // to indicate this. -func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) { +func (s *Stack) CopyOutBytes(sentinel hostarch.Addr, b []byte) (int, error) { if sentinel != StackBottomMagic { panic("Attempted to copy out to stack with absolute address") } c := len(b) - n, err := s.IO.CopyOut(context.Background(), s.Bottom-usermem.Addr(c), b, usermem.IOOpts{}) + n, err := s.IO.CopyOut(context.Background(), s.Bottom-hostarch.Addr(c), b, usermem.IOOpts{}) if err == nil && n == c { - s.Bottom -= usermem.Addr(n) + s.Bottom -= hostarch.Addr(n) } return n, err } @@ -81,21 +83,21 @@ func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) { // an appropriate address based on the current end of the stack. Callers must // use the sentinel address StackBottomMagic to marshal methods to indicate // this. -func (s *Stack) CopyInBytes(sentinel usermem.Addr, b []byte) (int, error) { +func (s *Stack) CopyInBytes(sentinel hostarch.Addr, b []byte) (int, error) { if sentinel != StackBottomMagic { panic("Attempted to copy in from stack with absolute address") } n, err := s.IO.CopyIn(context.Background(), s.Bottom, b, usermem.IOOpts{}) if err == nil { - s.Bottom += usermem.Addr(n) + s.Bottom += hostarch.Addr(n) } return n, err } // Align aligns the stack to the given offset. func (s *Stack) Align(offset int) { - if s.Bottom%usermem.Addr(offset) != 0 { - s.Bottom -= (s.Bottom % usermem.Addr(offset)) + if s.Bottom%hostarch.Addr(offset) != 0 { + s.Bottom -= (s.Bottom % hostarch.Addr(offset)) } } @@ -119,16 +121,16 @@ func (s *Stack) PushNullTerminatedByteSlice(bs []byte) (int, error) { // stack. type StackLayout struct { // ArgvStart is the beginning of the argument vector. - ArgvStart usermem.Addr + ArgvStart hostarch.Addr // ArgvEnd is the end of the argument vector. - ArgvEnd usermem.Addr + ArgvEnd hostarch.Addr // EnvvStart is the beginning of the environment vector. - EnvvStart usermem.Addr + EnvvStart hostarch.Addr // EnvvEnd is the end of the environment vector. - EnvvEnd usermem.Addr + EnvvEnd hostarch.Addr } // Load pushes the given args, env and aux vector to the stack using the @@ -148,7 +150,7 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) // to be in this order. See: https://www.uclibc.org/docs/psABI-x86_64.pdf // page 29. 
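// (Editor's note, not part of the patch: as the hunk below shows, Load pushes
// the environment strings first, then the argument strings beneath them,
// records their addresses in envAddrs/argAddrs, and only then aligns Bottom
// and pushes the zero-terminated auxv pairs.)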
l.EnvvEnd = s.Bottom - envAddrs := make([]usermem.Addr, len(env)) + envAddrs := make([]hostarch.Addr, len(env)) for i := len(env) - 1; i >= 0; i-- { if _, err := s.PushNullTerminatedByteSlice([]byte(env[i])); err != nil { return StackLayout{}, err @@ -159,7 +161,7 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) // Push our strings. l.ArgvEnd = s.Bottom - argAddrs := make([]usermem.Addr, len(args)) + argAddrs := make([]hostarch.Addr, len(args)) for i := len(args) - 1; i >= 0; i-- { if _, err := s.PushNullTerminatedByteSlice([]byte(args[i])); err != nil { return StackLayout{}, err @@ -178,7 +180,7 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) argvSize := s.Arch.Width() * uint(len(args)+1) envvSize := s.Arch.Width() * uint(len(env)+1) auxvSize := s.Arch.Width() * 2 * uint(len(aux)+1) - total := usermem.Addr(argvSize) + usermem.Addr(envvSize) + usermem.Addr(auxvSize) + usermem.Addr(s.Arch.Width()) + total := hostarch.Addr(argvSize) + hostarch.Addr(envvSize) + hostarch.Addr(auxvSize) + hostarch.Addr(s.Arch.Width()) expectedBottom := s.Bottom - total if expectedBottom%32 != 0 { s.Bottom -= expectedBottom % 32 @@ -188,11 +190,11 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) // NOTE: We need an extra zero here per spec. // The Push function will automatically terminate // strings and arrays with a single null value. - auxv := make([]usermem.Addr, 0, len(aux)) + auxv := make([]hostarch.Addr, 0, len(aux)) for _, a := range aux { - auxv = append(auxv, usermem.Addr(a.Key), a.Value) + auxv = append(auxv, hostarch.Addr(a.Key), a.Value) } - auxv = append(auxv, usermem.Addr(0)) + auxv = append(auxv, hostarch.Addr(0)) _, err := s.pushAddrSliceAndTerminator(auxv) if err != nil { return StackLayout{}, err diff --git a/pkg/sentry/arch/stack_unsafe.go b/pkg/sentry/arch/stack_unsafe.go index 0e478e434..f4712d58f 100644 --- a/pkg/sentry/arch/stack_unsafe.go +++ b/pkg/sentry/arch/stack_unsafe.go @@ -17,19 +17,19 @@ package arch import ( "unsafe" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" - "gvisor.dev/gvisor/pkg/usermem" ) // pushAddrSliceAndTerminator copies a slices of addresses to the stack, and // also pushes an extra null address element at the end of the slice. // // Internally, we unsafely transmute the slice type from the arch-dependent -// []usermem.Addr type, to a slice of fixed-sized ints so that we can pass it to +// []hostarch.Addr type, to a slice of fixed-sized ints so that we can pass it to // go-marshal. // // On error, the contents of the stack and the bottom cursor are undefined. -func (s *Stack) pushAddrSliceAndTerminator(src []usermem.Addr) (int, error) { +func (s *Stack) pushAddrSliceAndTerminator(src []hostarch.Addr) (int, error) { // Note: Stack grows upwards, so push the terminator first. switch s.Arch.Width() { case 8: diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go index 1929e41cd..49c53452a 100644 --- a/pkg/sentry/devices/memdev/zero.go +++ b/pkg/sentry/devices/memdev/zero.go @@ -93,6 +93,7 @@ func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro // "/dev/zero (deleted)". 
opts.Offset = 0 opts.MappingIdentity = &fd.vfsfd + opts.SentryOwnedContent = true opts.MappingIdentity.IncRef() return nil } diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD index 71c59287c..8b38d574d 100644 --- a/pkg/sentry/devices/tundev/BUILD +++ b/pkg/sentry/devices/tundev/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/inet", diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index c43158aa4..a12eeb8e7 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -18,6 +18,7 @@ package tundev import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -89,7 +90,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg } // Validate flags. - flags, err := netstack.LinuxToTUNFlags(usermem.ByteOrder.Uint16(req.Data[:])) + flags, err := netstack.LinuxToTUNFlags(hostarch.ByteOrder.Uint16(req.Data[:])) if err != nil { return 0, err } @@ -98,7 +99,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg case linux.TUNGETIFF: var req linux.IFReq copy(req.IFName[:], fd.device.Name()) - usermem.ByteOrder.PutUint16(req.Data[:], netstack.TUNFlagsToLinux(fd.device.Flags())) + hostarch.ByteOrder.PutUint16(req.Data[:], netstack.TUNFlagsToLinux(fd.device.Flags())) _, err := req.CopyOut(t, data) return 0, err diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 420fbae34..0dc100f9b 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/abi/linux", "//pkg/amutex", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/p9", "//pkg/refs", diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD index aedcecfa1..1ce56d79f 100644 --- a/pkg/sentry/fs/anon/BUILD +++ b/pkg/sentry/fs/anon/BUILD @@ -12,9 +12,9 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/anon/anon.go b/pkg/sentry/fs/anon/anon.go index 5c421f5fb..8bda22a8e 100644 --- a/pkg/sentry/fs/anon/anon.go +++ b/pkg/sentry/fs/anon/anon.go @@ -19,9 +19,9 @@ package anon import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/usermem" ) // NewInode constructs an anonymous Inode that is not associated @@ -37,6 +37,6 @@ func NewInode(ctx context.Context) *fs.Inode { Type: fs.Anonymous, DeviceID: PseudoDevice.DeviceID(), InodeID: PseudoDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, }) } diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 58deb25fc..5aa668873 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" @@ -339,7 +340,7 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err // size is the same used by io.Copy. 
var copyUpBuffers = sync.Pool{ New: func() interface{} { - b := make([]byte, 8*usermem.PageSize) + b := make([]byte, 8*hostarch.PageSize) return &b }, } diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 9379a4d7b..23a3a9a2d 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -18,6 +18,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/rand", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index acbd401a0..e84ba7a5d 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -19,6 +19,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" @@ -49,7 +50,7 @@ func newCharacterDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.M return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.CharacterDevice, DeviceFileMajor: major, DeviceFileMinor: minor, @@ -60,7 +61,7 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.CharacterDevice, DeviceFileMajor: memDevMajor, DeviceFileMinor: minor, @@ -72,7 +73,7 @@ func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.M return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Directory, }) } @@ -82,7 +83,7 @@ func newSymlink(ctx context.Context, target string, msrc *fs.MountSource) *fs.In return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Symlink, }) } @@ -137,7 +138,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Directory, }) } diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index 11a2984d8..77e8d222a 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -17,6 +17,7 @@ package dev import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -110,7 +111,7 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user } // Validate flags. 
- flags, err := netstack.LinuxToTUNFlags(usermem.ByteOrder.Uint16(req.Data[:])) + flags, err := netstack.LinuxToTUNFlags(hostarch.ByteOrder.Uint16(req.Data[:])) if err != nil { return 0, err } @@ -119,7 +120,7 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user case linux.TUNGETIFF: var req linux.IFReq copy(req.IFName[:], n.device.Name()) - usermem.ByteOrder.PutUint16(req.Data[:], netstack.TUNFlagsToLinux(n.device.Flags())) + hostarch.ByteOrder.PutUint16(req.Data[:], netstack.TUNFlagsToLinux(n.device.Flags())) _, err := req.CopyOut(t, data) return 0, err diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index c83baf464..2120f2bad 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -40,6 +40,7 @@ go_test( "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", + "//pkg/hostarch", "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/syserror", diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index faeb3908c..ab0e9dac7 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -27,6 +27,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) func singlePipeFD() (int, error) { @@ -52,7 +54,7 @@ func mockPipeDirent(t *testing.T) *fs.Dirent { } inode := fs.NewInode(ctx, node, fs.NewMockMountSource(nil), fs.StableAttr{ Type: fs.Pipe, - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, }) return fs.NewDirent(ctx, inode, "") } diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index d388f0e92..6469cc3a9 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -76,6 +76,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/safemem", "//pkg/sentry/arch", @@ -105,6 +106,7 @@ go_test( library = ":fsutil", deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/safemem", "//pkg/sentry/contexttest", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go index 2c9446c1d..38383e730 100644 --- a/pkg/sentry/fs/fsutil/dirty_set.go +++ b/pkg/sentry/fs/fsutil/dirty_set.go @@ -18,9 +18,9 @@ import ( "math" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/usermem" ) // DirtySet maps offsets into a memmap.Mappable to DirtyInfo. 
It is used to @@ -215,7 +215,7 @@ func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRan if max < wbr.Start { break } - ims, err := mem.MapInternal(cseg.FileRangeOf(wbr), usermem.Read) + ims, err := mem.MapInternal(cseg.FileRangeOf(wbr), hostarch.Read) if err != nil { return err } diff --git a/pkg/sentry/fs/fsutil/dirty_set_test.go b/pkg/sentry/fs/fsutil/dirty_set_test.go index e3579c23c..48448c97c 100644 --- a/pkg/sentry/fs/fsutil/dirty_set_test.go +++ b/pkg/sentry/fs/fsutil/dirty_set_test.go @@ -18,18 +18,18 @@ import ( "reflect" "testing" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/usermem" ) func TestDirtySet(t *testing.T) { var set DirtySet - set.MarkDirty(memmap.MappableRange{0, 2 * usermem.PageSize}) - set.KeepDirty(memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize}) - set.MarkClean(memmap.MappableRange{0, 2 * usermem.PageSize}) + set.MarkDirty(memmap.MappableRange{0, 2 * hostarch.PageSize}) + set.KeepDirty(memmap.MappableRange{hostarch.PageSize, 2 * hostarch.PageSize}) + set.MarkClean(memmap.MappableRange{0, 2 * hostarch.PageSize}) want := &DirtySegmentDataSlices{ - Start: []uint64{usermem.PageSize}, - End: []uint64{2 * usermem.PageSize}, + Start: []uint64{hostarch.PageSize}, + End: []uint64{2 * hostarch.PageSize}, Values: []DirtyInfo{{Keep: true}}, } if got := set.ExportSortedSlices(); !reflect.DeepEqual(got, want) { diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index 1dc409d38..fdaceb1db 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -20,11 +20,11 @@ import ( "math" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/usermem" ) // FileRangeSet maps offsets into a memmap.Mappable to offsets into a @@ -130,7 +130,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map // MemoryFile.AllocateAndFill truncates down to a page // boundary, but FileRangeSet.Fill is supposed to // zero-fill to the end of the page in this case. - donepgaddr, ok := usermem.Addr(done).RoundUp() + donepgaddr, ok := hostarch.Addr(done).RoundUp() if donepg := uint64(donepgaddr); ok && donepg != done { dsts.DropFirst64(donepg - done) done = donepg @@ -184,7 +184,7 @@ func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) { // bytes after the new EOF on the same page are zeroed, and pages after the new // EOF are freed. func (frs *FileRangeSet) Truncate(end uint64, mf *pgalloc.MemoryFile) { - pgendaddr, ok := usermem.Addr(end).RoundUp() + pgendaddr, ok := hostarch.Addr(end).RoundUp() if ok { pgend := uint64(pgendaddr) @@ -208,7 +208,7 @@ func (frs *FileRangeSet) Truncate(end uint64, mf *pgalloc.MemoryFile) { if seg.Ok() { fr := seg.FileRange() fr.Start += end - seg.Start() - ims, err := mf.MapInternal(fr, usermem.Write) + ims, err := mf.MapInternal(fr, hostarch.Write) if err != nil { // There's no good recourse from here. 
This means // that we can't keep cached memory consistent with diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 54f7b7cdc..23528bf25 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -18,11 +18,11 @@ import ( "fmt" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) // HostFileMapper caches mappings of an arbitrary host file descriptor. It is @@ -50,13 +50,13 @@ type HostFileMapper struct { } const ( - chunkShift = usermem.HugePageShift + chunkShift = hostarch.HugePageShift chunkSize = 1 << chunkShift chunkMask = chunkSize - 1 ) func pagesInChunk(mr memmap.MappableRange, chunkStart uint64) int32 { - return int32(mr.Intersect(memmap.MappableRange{chunkStart, chunkStart + chunkSize}).Length() / usermem.PageSize) + return int32(mr.Intersect(memmap.MappableRange{chunkStart, chunkStart + chunkSize}).Length() / hostarch.PageSize) } type mapping struct { diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index c15d8a946..e1e38b498 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -59,7 +60,7 @@ func NewHostMappable(backingFile CachedFileObject) *HostMappable { } // AddMapping implements memmap.Mappable.AddMapping. -func (h *HostMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +func (h *HostMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { // Hot path. Avoid defers. h.mu.Lock() mapped := h.mappings.AddMapping(ms, ar, offset, writable) @@ -71,7 +72,7 @@ func (h *HostMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, a } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (h *HostMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (h *HostMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { // Hot path. Avoid defers. h.mu.Lock() unmapped := h.mappings.RemoveMapping(ms, ar, offset, writable) @@ -82,18 +83,18 @@ func (h *HostMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace } // CopyMapping implements memmap.Mappable.CopyMapping. -func (h *HostMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (h *HostMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { return h.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. 
-func (h *HostMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (h *HostMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { return []memmap.Translation{ { Source: optional, File: h, Offset: optional.Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }, }, nil } @@ -124,7 +125,7 @@ func (h *HostMappable) NotifyChangeFD() error { } // MapInternal implements memmap.File.MapInternal. -func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (h *HostMappable) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) { return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write) } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 0ed7aafa5..7856b354b 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -19,6 +19,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -622,7 +623,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { switch { case seg.Ok(): // Get internal mappings from the cache. - ims, err := mem.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) + ims, err := mem.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Read) if err != nil { unlock() return done, err @@ -647,7 +648,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // Read into the cache, then re-enter the loop to read from the // cache. reqMR := memmap.MappableRange{ - Start: uint64(usermem.Addr(gapMR.Start).RoundDown()), + Start: uint64(hostarch.Addr(gapMR.Start).RoundDown()), End: fs.OffsetPageEnd(int64(gapMR.End)), } optMR := gap.Range() @@ -729,7 +730,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error case seg.Ok() && seg.Start() < mr.End: // Get internal mappings from the cache. segMR := seg.Range().Intersect(mr) - ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write) + ims, err := mf.MapInternal(seg.FileRangeOf(segMR), hostarch.Write) if err != nil { rw.maybeGrowFile() rw.c.dataMu.Unlock() @@ -786,7 +787,7 @@ func (c *CachingInodeOperations) useHostPageCache() bool { } // AddMapping implements memmap.Mappable.AddMapping. -func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { // Hot path. Avoid defers. c.mapsMu.Lock() mapped := c.mappings.AddMapping(ms, ar, offset, writable) @@ -808,7 +809,7 @@ func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.Mappi } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { // Hot path. Avoid defers. 
c.mapsMu.Lock() unmapped := c.mappings.RemoveMapping(ms, ar, offset, writable) @@ -836,12 +837,12 @@ func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Ma } // CopyMapping implements memmap.Mappable.CopyMapping. -func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { return c.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. -func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { // Hot path. Avoid defer. if c.useHostPageCache() { mr := optional @@ -853,7 +854,7 @@ func (c *CachingInodeOperations) Translate(ctx context.Context, required, option Source: mr, File: c, Offset: mr.Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }, }, nil } @@ -885,7 +886,7 @@ func (c *CachingInodeOperations) Translate(ctx context.Context, required, option segMR := seg.Range().Intersect(optional) // TODO(jamieliu): Make Translations writable even if writability is // not required if already kept-dirty by another writable translation. - perms := usermem.AccessType{ + perms := hostarch.AccessType{ Read: true, Execute: true, } @@ -1050,7 +1051,7 @@ func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) { // MapInternal implements memmap.File.MapInternal. This is used when we // directly map an underlying host fd and CachingInodeOperations is used as the // memmap.File during translation. -func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) { return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write) } diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go index 1547584c5..e107c3096 100644 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ b/pkg/sentry/fs/fsutil/inode_cached_test.go @@ -20,6 +20,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -249,7 +250,7 @@ func (f *sliceBackingFile) Allocate(ctx context.Context, offset int64, length in type noopMappingSpace struct{} // Invalidate implements memmap.MappingSpace.Invalidate. -func (noopMappingSpace) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) { +func (noopMappingSpace) Invalidate(ar hostarch.AddrRange, opts memmap.InvalidateOpts) { } func anonInode(ctx context.Context) *fs.Inode { @@ -259,14 +260,14 @@ func anonInode(ctx context.Context) *fs.Inode { }, 0), }, fs.NewPseudoMountSource(ctx), fs.StableAttr{ Type: fs.Anonymous, - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, }) } func pagesOf(bs ...byte) []byte { - buf := make([]byte, 0, len(bs)*usermem.PageSize) + buf := make([]byte, 0, len(bs)*hostarch.PageSize) for _, b := range bs { - buf = append(buf, bytes.Repeat([]byte{b}, usermem.PageSize)...) 
+ buf = append(buf, bytes.Repeat([]byte{b}, hostarch.PageSize)...) } return buf } @@ -292,28 +293,28 @@ func TestRead(t *testing.T) { // expects to only cache mapped pages), then call Translate to force it to // be cached. var ms noopMappingSpace - ar := usermem.AddrRange{usermem.PageSize, 2 * usermem.PageSize} - if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil { + ar := hostarch.AddrRange{hostarch.PageSize, 2 * hostarch.PageSize} + if err := iops.AddMapping(ctx, ms, ar, hostarch.PageSize, true); err != nil { t.Fatalf("AddMapping got %v, want nil", err) } - mr := memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize} - if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil { + mr := memmap.MappableRange{hostarch.PageSize, 2 * hostarch.PageSize} + if _, err := iops.Translate(ctx, mr, mr, hostarch.Read); err != nil { t.Fatalf("Translate got %v, want nil", err) } - if cached := iops.cache.Span(); cached != usermem.PageSize { - t.Errorf("SpanRange got %d, want %d", cached, usermem.PageSize) + if cached := iops.cache.Span(); cached != hostarch.PageSize { + t.Errorf("SpanRange got %d, want %d", cached, hostarch.PageSize) } // Try to read 4 pages. The first and third pages should be read directly // from the "file", the second page should be read from the cache, and only // 3 pages (the size of the file) should be readable. - rbuf := make([]byte, 4*usermem.PageSize) + rbuf := make([]byte, 4*hostarch.PageSize) dst := usermem.BytesIOSequence(rbuf) n, err := iops.Read(ctx, file, dst, 0) - if n != 3*usermem.PageSize || (err != nil && err != io.EOF) { - t.Fatalf("Read got (%d, %v), want (%d, nil or EOF)", n, err, 3*usermem.PageSize) + if n != 3*hostarch.PageSize || (err != nil && err != io.EOF) { + t.Fatalf("Read got (%d, %v), want (%d, nil or EOF)", n, err, 3*hostarch.PageSize) } - rbuf = rbuf[:3*usermem.PageSize] + rbuf = rbuf[:3*hostarch.PageSize] // Did we get the bytes we expect? if !bytes.Equal(rbuf, buf) { @@ -323,7 +324,7 @@ func TestRead(t *testing.T) { // Delete the memory mapping before iops.Release(). The cached page will // either be evicted by ctx's pgalloc.MemoryFile, or dropped by // iops.Release(). - iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true) + iops.RemoveMapping(ctx, ms, ar, hostarch.PageSize, true) } func TestWrite(t *testing.T) { @@ -348,25 +349,25 @@ func TestWrite(t *testing.T) { // CachingInodeOperations expects to only cache mapped pages), then call // Translate to force them to be cached. 
var ms noopMappingSpace - ar := usermem.AddrRange{usermem.PageSize, 3 * usermem.PageSize} - if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil { + ar := hostarch.AddrRange{hostarch.PageSize, 3 * hostarch.PageSize} + if err := iops.AddMapping(ctx, ms, ar, hostarch.PageSize, true); err != nil { t.Fatalf("AddMapping got %v, want nil", err) } - defer iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true) - mr := memmap.MappableRange{usermem.PageSize, 3 * usermem.PageSize} - if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil { + defer iops.RemoveMapping(ctx, ms, ar, hostarch.PageSize, true) + mr := memmap.MappableRange{hostarch.PageSize, 3 * hostarch.PageSize} + if _, err := iops.Translate(ctx, mr, mr, hostarch.Read); err != nil { t.Fatalf("Translate got %v, want nil", err) } - if cached := iops.cache.Span(); cached != 2*usermem.PageSize { - t.Errorf("SpanRange got %d, want %d", cached, 2*usermem.PageSize) + if cached := iops.cache.Span(); cached != 2*hostarch.PageSize { + t.Errorf("SpanRange got %d, want %d", cached, 2*hostarch.PageSize) } // Write to the first 2 pages. wbuf := pagesOf('e', 'f') src := usermem.BytesIOSequence(wbuf) n, err := iops.Write(ctx, src, 0) - if n != 2*usermem.PageSize || err != nil { - t.Fatalf("Write got (%d, %v), want (%d, nil)", n, err, 2*usermem.PageSize) + if n != 2*hostarch.PageSize || err != nil { + t.Fatalf("Write got (%d, %v), want (%d, nil)", n, err, 2*hostarch.PageSize) } // The first page should have been written directly, since it was not cached. @@ -382,7 +383,7 @@ func TestWrite(t *testing.T) { } // Now the second page should have been written as well. - copy(want[usermem.PageSize:], pagesOf('f')) + copy(want[hostarch.PageSize:], pagesOf('f')) if !bytes.Equal(buf, want) { t.Errorf("File contents are %v, want %v", buf, want) } diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index b210e0e7e..c4a069832 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fd", + "//pkg/hostarch", "//pkg/log", "//pkg/p9", "//pkg/refs", diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go index cffc756cc..d6bff3f40 100644 --- a/pkg/sentry/fs/gofer/attr.go +++ b/pkg/sentry/fs/gofer/attr.go @@ -17,11 +17,11 @@ package gofer import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/usermem" ) // getattr returns the 9p attributes of the p9.File. On success, Mode, Size, and RDev @@ -98,7 +98,7 @@ func bsize(pattr p9.Attr) int64 { // Some files, particularly those that are not on a local file system, // may have no clue of their block size. Better not to report something // misleading or buggy and have a safe default. - return usermem.PageSize + return hostarch.PageSize } // ntype returns an fs.InodeType from 9p attributes. diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 0b3d0617f..46a2dc47d 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -384,8 +384,16 @@ func (c *ConnectedEndpoint) CloseUnread() {} // SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize. func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { - // gVisor does not permit setting of SO_SNDBUF for host backed unix domain - // sockets. 
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix + // domain sockets. + return atomic.LoadInt64(&c.sndbuf) +} + +// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize. +func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) { + // gVisor does not permit setting of SO_RCVBUF for host backed unix + // domain sockets. Receive buffer does not have any effect for unix + // sockets and we claim to be the same as send buffer. return atomic.LoadInt64(&c.sndbuf) } diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index fb81d903d..1b83643db 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" @@ -216,7 +217,7 @@ func (i *Inotify) Ioctl(ctx context.Context, _ *File, io usermem.IO, args arch.S n += uint32(e.sizeOf()) } var buf [4]byte - usermem.ByteOrder.PutUint32(buf[:], n) + hostarch.ByteOrder.PutUint32(buf[:], n) _, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{}) return 0, err diff --git a/pkg/sentry/fs/inotify_event.go b/pkg/sentry/fs/inotify_event.go index 686e1b1cd..399aff1ed 100644 --- a/pkg/sentry/fs/inotify_event.go +++ b/pkg/sentry/fs/inotify_event.go @@ -19,6 +19,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/usermem" ) @@ -100,10 +101,10 @@ func (e *Event) sizeOf() int { // construct the output. We use a buffer allocated ahead of time for // performance. buf must be at least inotifyEventBaseSize bytes. func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) { - usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd)) - usermem.ByteOrder.PutUint32(buf[4:], e.mask) - usermem.ByteOrder.PutUint32(buf[8:], e.cookie) - usermem.ByteOrder.PutUint32(buf[12:], e.len) + hostarch.ByteOrder.PutUint32(buf[0:], uint32(e.wd)) + hostarch.ByteOrder.PutUint32(buf[4:], e.mask) + hostarch.ByteOrder.PutUint32(buf[8:], e.cookie) + hostarch.ByteOrder.PutUint32(buf[12:], e.len) writeLen := 0 diff --git a/pkg/sentry/fs/offset.go b/pkg/sentry/fs/offset.go index 53b5df175..3a8c97d8f 100644 --- a/pkg/sentry/fs/offset.go +++ b/pkg/sentry/fs/offset.go @@ -17,14 +17,14 @@ package fs import ( "math" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // OffsetPageEnd returns the file offset rounded up to the nearest // page boundary. OffsetPageEnd panics if rounding up causes overflow, // which shouldn't be possible given that offset is an int64. func OffsetPageEnd(offset int64) uint64 { - end, ok := usermem.Addr(offset).RoundUp() + end, ok := hostarch.Addr(offset).RoundUp() if !ok { panic("impossible overflow") } diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 01a1235b8..f96f5a3e5 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -19,11 +19,11 @@ import ( "strings" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // The virtual filesystem implements an overlay configuration. For a high-level @@ -274,7 +274,7 @@ func (o *overlayEntry) markDirectoryDirty() { } // AddMapping implements memmap.Mappable.AddMapping. 
-func (o *overlayEntry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +func (o *overlayEntry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { o.mapsMu.Lock() defer o.mapsMu.Unlock() if err := o.inodeLocked().Mappable().AddMapping(ctx, ms, ar, offset, writable); err != nil { @@ -285,7 +285,7 @@ func (o *overlayEntry) AddMapping(ctx context.Context, ms memmap.MappingSpace, a } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (o *overlayEntry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (o *overlayEntry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { o.mapsMu.Lock() defer o.mapsMu.Unlock() o.inodeLocked().Mappable().RemoveMapping(ctx, ms, ar, offset, writable) @@ -293,7 +293,7 @@ func (o *overlayEntry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace } // CopyMapping implements memmap.Mappable.CopyMapping. -func (o *overlayEntry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (o *overlayEntry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { o.mapsMu.Lock() defer o.mapsMu.Unlock() if err := o.inodeLocked().Mappable().CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil { @@ -304,7 +304,7 @@ func (o *overlayEntry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, } // Translate implements memmap.Mappable.Translate. -func (o *overlayEntry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (o *overlayEntry) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { o.dataMu.RLock() defer o.dataMu.RUnlock() return o.inodeLocked().Mappable().Translate(ctx, required, optional, at) diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index b8b2281a8..7af7e0b45 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -30,6 +30,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go index e6171dd1d..24426b225 100644 --- a/pkg/sentry/fs/proc/exec_args.go +++ b/pkg/sentry/fs/proc/exec_args.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -113,7 +114,7 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen defer m.DecUsers(ctx) // Figure out the bounds of the exec arg we are trying to read. - var execArgStart, execArgEnd usermem.Addr + var execArgStart, execArgEnd hostarch.Addr switch f.arg { case cmdlineExecArg: execArgStart, execArgEnd = m.ArgvStart(), m.ArgvEnd() @@ -172,8 +173,8 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208 // we'll return one page total between argv and envp because of the // above page restrictions. 
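	// For illustration (assumed numbers, not taken from this change): with a
	// 4 KiB page, if the argv bytes already in buf total 300 and the
	// environment block is 8 KiB long, the clamp below reduces lengthEnvv to
	// 4096 - 300 = 3796, so the combined argv+envp output never exceeds one
	// page.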
- if lengthEnvv > usermem.PageSize-len(buf) { - lengthEnvv = usermem.PageSize - len(buf) + if lengthEnvv > hostarch.PageSize-len(buf) { + lengthEnvv = hostarch.PageSize - len(buf) } // Make a new buffer to fit the whole thing tmp := make([]byte, length+lengthEnvv) diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go index d2859a4c2..78132f7a5 100644 --- a/pkg/sentry/fs/proc/inode.go +++ b/pkg/sentry/fs/proc/inode.go @@ -17,13 +17,13 @@ package proc import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange @@ -125,7 +125,7 @@ func newProcInode(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo sattr := fs.StableAttr{ DeviceID: device.ProcDevice.DeviceID(), InodeID: device.ProcDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: typ, } if t != nil { diff --git a/pkg/sentry/fs/proc/meminfo.go b/pkg/sentry/fs/proc/meminfo.go index 91617267d..7d975d333 100644 --- a/pkg/sentry/fs/proc/meminfo.go +++ b/pkg/sentry/fs/proc/meminfo.go @@ -19,10 +19,10 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange @@ -53,7 +53,7 @@ func (d *meminfoData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) anon := snapshot.Anonymous + snapshot.Tmpfs file := snapshot.PageCache + snapshot.Mapped // We don't actually have active/inactive LRUs, so just make up numbers. - activeFile := (file / 2) &^ (usermem.PageSize - 1) + activeFile := (file / 2) &^ (hostarch.PageSize - 1) inactiveFile := file - activeFile var buf bytes.Buffer diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 203cfa061..91c35eea9 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" @@ -35,7 +36,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange @@ -367,10 +367,10 @@ func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([] ) if len(rt.GatewayAddr) == header.IPv4AddressSize { flags |= linux.RTF_GATEWAY - gw = usermem.ByteOrder.Uint32(rt.GatewayAddr) + gw = hostarch.ByteOrder.Uint32(rt.GatewayAddr) } if len(rt.DstAddr) == header.IPv4AddressSize { - prefix = usermem.ByteOrder.Uint32(rt.DstAddr) + prefix = hostarch.ByteOrder.Uint32(rt.DstAddr) } l := fmt.Sprintf( "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d", @@ -520,7 +520,7 @@ func networkToHost16(n uint16) uint16 { // binary.BigEndian.Uint16() require a read of binary.BigEndian and an // interface method call, defeating inlining. 
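	// For comparison, a sketch of the encoding/binary form the comment above
	// refers to; it assumes an `import "encoding/binary"` and is not part of
	// this change:
	//
	//	var b [2]byte
	//	hostarch.ByteOrder.PutUint16(b[:], n) // lay n out in host byte order
	//	return binary.BigEndian.Uint16(b[:])  // reread the bytes as network order
	//
	// On a little-endian host both forms reduce to a byte swap of n; on a
	// big-endian host both are the identity.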
buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)} - return usermem.ByteOrder.Uint16(buf[:]) + return hostarch.ByteOrder.Uint16(buf[:]) } func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { @@ -542,14 +542,14 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { // __be32 which is a typedef for an unsigned int, and is printed with // %X. This means that for a little-endian machine, Linux prints the // least-significant byte of the address first. To emulate this, we first - // invert the byte order for the address using usermem.ByteOrder.Uint32, + // invert the byte order for the address using hostarch.ByteOrder.Uint32, // which makes it have the equivalent encoding to a __be32 on a little // endian machine. Note that this operation is a no-op on a big endian // machine. Then similar to Linux, we format it with %X, which will print // the most-significant byte of the __be32 address first, which is now // actually the least-significant byte of the original address in // linux.SockAddrInet.Addr on little endian machines, due to the conversion. - addr := usermem.ByteOrder.Uint32(a.Addr[:]) + addr := hostarch.ByteOrder.Uint32(a.Addr[:]) fmt.Fprintf(w, "%08X:%04X ", addr, port) case linux.AF_INET6: @@ -559,10 +559,10 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { } port := networkToHost16(a.Port) - addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4]) - addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8]) - addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12]) - addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16]) + addr0 := hostarch.ByteOrder.Uint32(a.Addr[0:4]) + addr1 := hostarch.ByteOrder.Uint32(a.Addr[4:8]) + addr2 := hostarch.ByteOrder.Uint32(a.Addr[8:12]) + addr3 := hostarch.ByteOrder.Uint32(a.Addr[12:16]) fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port) } } diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 21338d912..713b81e08 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/proc/device", diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go index 6121f0e95..b01688b1d 100644 --- a/pkg/sentry/fs/proc/seqfile/seqfile.go +++ b/pkg/sentry/fs/proc/seqfile/seqfile.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" @@ -131,7 +132,7 @@ func NewSeqFileInode(ctx context.Context, source SeqSource, msrc *fs.MountSource sattr := fs.StableAttr{ DeviceID: device.ProcDevice.DeviceID(), InodeID: device.ProcDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.SpecialFile, } return fs.NewInode(ctx, iops, msrc, sattr) diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index bbe282c03..1d09afdd7 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" @@ -76,7 +77,7 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir sattr := fs.StableAttr{ 
DeviceID: device.ProcDevice.DeviceID(), InodeID: device.ProcDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.SpecialFile, } return fs.NewInode(ctx, tm, msrc, sattr) @@ -136,7 +137,7 @@ func (f *tcpMemFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequen f.tcpMemInode.mu.Lock() defer f.tcpMemInode.mu.Unlock() - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) size, err := readSize(f.tcpMemInode.dir, f.tcpMemInode.s) if err != nil { return 0, err @@ -192,7 +193,7 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f sattr := fs.StableAttr{ DeviceID: device.ProcDevice.DeviceID(), InodeID: device.ProcDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.SpecialFile, } return fs.NewInode(ctx, ts, msrc, sattr) @@ -264,7 +265,7 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque // Only consider size of one memory page for input for performance reasons. // We are only reading if it's zero or not anyway. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) @@ -294,7 +295,7 @@ func newTCPRecoveryInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack sattr := fs.StableAttr{ DeviceID: device.ProcDevice.DeviceID(), InodeID: device.ProcDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.SpecialFile, } return fs.NewInode(ctx, ts, msrc, sattr) @@ -354,7 +355,7 @@ func (f *tcpRecoveryFile) Write(ctx context.Context, _ *fs.File, src usermem.IOS if src.NumBytes() == 0 { return 0, nil } - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) @@ -413,7 +414,7 @@ func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stac sattr := fs.StableAttr{ DeviceID: device.ProcDevice.DeviceID(), InodeID: device.ProcDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.SpecialFile, } return fs.NewInode(ctx, ipf, msrc, sattr) @@ -486,7 +487,7 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO // Only consider size of one memory page for input for performance reasons. // We are only reading if it's zero or not anyway. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) @@ -524,7 +525,7 @@ func newPortRangeInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) sattr := fs.StableAttr{ DeviceID: device.ProcDevice.DeviceID(), InodeID: device.ProcDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.SpecialFile, } return fs.NewInode(ctx, ipf, msrc, sattr) @@ -589,7 +590,7 @@ func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSe // Only consider size of one memory page for input for performance // reasons. 
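	// Illustrative example (assuming this inode backs
	// /proc/sys/net/ipv4/ip_local_port_range): writing the string
	// "32768 60999" is limited to one page by TakeFirst below and then parsed
	// by CopyInt32StringsInVec into ports[0] = 32768 and ports[1] = 60999.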
- src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) ports := make([]int32, 2) n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts) diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index f43d6c221..ae5ed25f9 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" @@ -469,7 +470,7 @@ func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen defer mm.DecUsers(ctx) // Buffer the read data because of MM locks buf := make([]byte, dst.NumBytes()) - n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + n, readErr := mm.CopyIn(ctx, hostarch.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) if n > 0 { if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { return 0, syserror.EFAULT @@ -632,7 +633,7 @@ func (s *taskStatData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) rss = mm.ResidentSetSize() } }) - fmt.Fprintf(&buf, "%d %d ", vss, rss/usermem.PageSize) + fmt.Fprintf(&buf, "%d %d ", vss, rss/hostarch.PageSize) // rsslim. fmt.Fprintf(&buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur) @@ -684,7 +685,7 @@ func (s *statmData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([ }) var buf bytes.Buffer - fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize) + fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize) return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statmData)(nil)}}, 0 } @@ -939,8 +940,8 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc buf := make([]byte, size) for i, e := range auxv { - usermem.ByteOrder.PutUint64(buf[16*i:], e.Key) - usermem.ByteOrder.PutUint64(buf[16*i+8:], uint64(e.Value)) + hostarch.ByteOrder.PutUint64(buf[16*i:], e.Key) + hostarch.ByteOrder.PutUint64(buf[16*i+8:], uint64(e.Value)) } n, err := dst.CopyOut(ctx, buf[offset:]) @@ -1020,7 +1021,7 @@ func (f *oomScoreAdjFile) Write(ctx context.Context, _ *fs.File, src usermem.IOS } // Limit input size so as not to impact performance if input size is large. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go index 2bc9485d8..30d5ad4cf 100644 --- a/pkg/sentry/fs/proc/uid_gid_map.go +++ b/pkg/sentry/fs/proc/uid_gid_map.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -132,7 +133,7 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u // the system page size, and the write must be performed at the start of // the file ..." 
- user_namespaces(7) srclen := src.NumBytes() - if srclen >= usermem.PageSize || offset != 0 { + if srclen >= hostarch.PageSize || offset != 0 { return 0, syserror.EINVAL } b := make([]byte, srclen) diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index a51d00d86..4a3d9636b 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -14,13 +14,13 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/socket/unix/transport", "//pkg/sync", "//pkg/syserror", - "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/fs/ramfs/tree.go b/pkg/sentry/fs/ramfs/tree.go index dfc9d3453..0ace636c9 100644 --- a/pkg/sentry/fs/ramfs/tree.go +++ b/pkg/sentry/fs/ramfs/tree.go @@ -20,9 +20,9 @@ import ( "strings" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/usermem" ) // MakeDirectoryTree constructs a ramfs tree of all directories containing @@ -71,7 +71,7 @@ func emptyDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { return fs.NewInode(ctx, dir, msrc, fs.StableAttr{ DeviceID: anon.PseudoDevice.DeviceID(), InodeID: anon.PseudoDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Directory, }) } diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD index f2e8b9932..fdbc5f912 100644 --- a/pkg/sentry/fs/sys/BUILD +++ b/pkg/sentry/fs/sys/BUILD @@ -14,11 +14,11 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/kernel", - "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/sys/sys.go b/pkg/sentry/fs/sys/sys.go index 0891645e4..101779a7a 100644 --- a/pkg/sentry/fs/sys/sys.go +++ b/pkg/sentry/fs/sys/sys.go @@ -17,16 +17,16 @@ package sys import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/usermem" ) func newFile(ctx context.Context, node fs.InodeOperations, msrc *fs.MountSource) *fs.Inode { sattr := fs.StableAttr{ DeviceID: sysfsDevice.DeviceID(), InodeID: sysfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.SpecialFile, } return fs.NewInode(ctx, node, msrc, sattr) @@ -37,7 +37,7 @@ func newDir(ctx context.Context, msrc *fs.MountSource, contents map[string]*fs.I return fs.NewInode(ctx, d, msrc, fs.StableAttr{ DeviceID: sysfsDevice.DeviceID(), InodeID: sysfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.SpecialDirectory, }) } diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index d16cdb4df..c7977a217 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -8,6 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index 46511a6ac..c8ebe256c 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" 
"gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -124,7 +125,7 @@ func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.I } if val := atomic.SwapUint64(&t.val, 0); val != 0 { var buf [sizeofUint64]byte - usermem.ByteOrder.PutUint64(buf[:], val) + hostarch.ByteOrder.PutUint64(buf[:], val) if _, err := dst.CopyOut(ctx, buf[:]); err != nil { // Linux does not undo consuming the number of expirations even if // writing to userspace fails. diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index b521a86a2..90398376a 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/safemem", "//pkg/sentry/device", "//pkg/sentry/fs", @@ -42,6 +43,7 @@ go_test( library = ":tmpfs", deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/fs", "//pkg/sentry/kernel/contexttest", "//pkg/sentry/usage", diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go index d4d613ea9..1718f9372 100644 --- a/pkg/sentry/fs/tmpfs/file_test.go +++ b/pkg/sentry/fs/tmpfs/file_test.go @@ -19,6 +19,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/usage" @@ -31,7 +32,7 @@ func newFileInode(ctx context.Context) *fs.Inode { return fs.NewInode(ctx, iops, m, fs.StableAttr{ DeviceID: tmpfsDevice.DeviceID(), InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.RegularFile, }) } diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index ad4aea282..f4de8c968 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -125,7 +126,7 @@ func NewMemfdInode(ctx context.Context, allowSeals bool) *fs.Inode { Type: fs.RegularFile, DeviceID: tmpfsDevice.DeviceID(), InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, }) } @@ -392,7 +393,7 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { switch { case seg.Ok(): // Get internal mappings. - ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) + ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Read) if err != nil { return done, err } @@ -463,7 +464,7 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) // // See Linux, mm/filemap.c:generic_perform_write() and // mm/shmem.c:shmem_write_begin(). - if pgstart := int64(usermem.Addr(rw.f.attr.Size).RoundDown()); end > pgstart { + if pgstart := int64(hostarch.Addr(rw.f.attr.Size).RoundDown()); end > pgstart { end = pgstart } if end <= rw.offset { @@ -483,8 +484,8 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) mf := rw.f.kernel.MemoryFile() // Page-aligned mr for when we need to allocate memory. RoundUp can't // overflow since end is an int64. 
- pgstartaddr := usermem.Addr(rw.offset).RoundDown() - pgendaddr, _ := usermem.Addr(end).RoundUp() + pgstartaddr := hostarch.Addr(rw.offset).RoundDown() + pgendaddr, _ := hostarch.Addr(end).RoundUp() pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)} var done uint64 @@ -494,7 +495,7 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) switch { case seg.Ok(): // Get internal mappings. - ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write) + ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Write) if err != nil { return done, err } @@ -527,7 +528,7 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) } // AddMapping implements memmap.Mappable.AddMapping. -func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { f.mapsMu.Lock() defer f.mapsMu.Unlock() @@ -544,7 +545,7 @@ func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingS pagesBefore := f.writableMappingPages // ar is guaranteed to be page aligned per memmap.Mappable. - f.writableMappingPages += uint64(ar.Length() / usermem.PageSize) + f.writableMappingPages += uint64(ar.Length() / hostarch.PageSize) if f.writableMappingPages < pagesBefore { panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, f.writableMappingPages)) @@ -555,7 +556,7 @@ func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingS } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { f.mapsMu.Lock() defer f.mapsMu.Unlock() @@ -565,7 +566,7 @@ func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Mappi pagesBefore := f.writableMappingPages // ar is guaranteed to be page aligned per memmap.Mappable. - f.writableMappingPages -= uint64(ar.Length() / usermem.PageSize) + f.writableMappingPages -= uint64(ar.Length() / hostarch.PageSize) if f.writableMappingPages > pagesBefore { panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, f.writableMappingPages)) @@ -574,12 +575,12 @@ func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Mappi } // CopyMapping implements memmap.Mappable.CopyMapping. -func (f *fileInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (f *fileInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { return f.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. 
-func (f *fileInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (f *fileInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { f.dataMu.Lock() defer f.dataMu.Unlock() @@ -612,7 +613,7 @@ func (f *fileInodeOperations) Translate(ctx context.Context, required, optional Source: segMR, File: mf, Offset: seg.FileRangeOf(segMR).Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }) translatedEnd = segMR.End } diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index cf4ed5de0..577052888 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" @@ -28,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) var fsInfo = fs.Info{ @@ -41,8 +41,8 @@ var fsInfo = fs.Info{ // chosen to ensure that BlockSize * Blocks does not overflow int64 (which // applications may also handle incorrectly). // TODO(b/29637826): allow configuring a tmpfs size and enforce it. - TotalBlocks: math.MaxInt64 / usermem.PageSize, - FreeBlocks: math.MaxInt64 / usermem.PageSize, + TotalBlocks: math.MaxInt64 / hostarch.PageSize, + FreeBlocks: math.MaxInt64 / hostarch.PageSize, } // rename implements fs.InodeOperations.Rename for tmpfs nodes. @@ -99,7 +99,7 @@ func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwn return fs.NewInode(ctx, d, msrc, fs.StableAttr{ DeviceID: tmpfsDevice.DeviceID(), InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Directory, }) } @@ -232,7 +232,7 @@ func (d *Dir) newCreateOps() *ramfs.CreateOps { return fs.NewInode(ctx, iops, dir.MountSource, fs.StableAttr{ DeviceID: tmpfsDevice.DeviceID(), InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.RegularFile, }), nil }, @@ -281,7 +281,7 @@ func NewSymlink(ctx context.Context, target string, owner fs.FileOwner, msrc *fs return fs.NewInode(ctx, s, msrc, fs.StableAttr{ DeviceID: tmpfsDevice.DeviceID(), InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Symlink, }) } @@ -311,7 +311,7 @@ func NewSocket(ctx context.Context, socket transport.BoundEndpoint, owner fs.Fil return fs.NewInode(ctx, s, msrc, fs.StableAttr{ DeviceID: tmpfsDevice.DeviceID(), InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Socket, }) } @@ -348,7 +348,7 @@ func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, return fs.NewInode(ctx, fifoIops, msrc, fs.StableAttr{ DeviceID: tmpfsDevice.DeviceID(), InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Pipe, }) } diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index e6d0eb359..86ada820e 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -17,6 +17,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", diff --git 
a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index c2da80bc2..13c9dbe7d 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -122,7 +123,7 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode { // TODO(b/75267214): Since ptsDevice must be shared between // different mounts, we must not assign fixed numbers. InodeID: ptsDevice.NextIno(), - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, Type: fs.Directory, }) } diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD new file mode 100644 index 000000000..37efb641a --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/BUILD @@ -0,0 +1,48 @@ +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +licenses(["notice"]) + +go_template_instance( + name = "dir_refs", + out = "dir_refs.go", + package = "cgroupfs", + prefix = "dir", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "dir", + }, +) + +go_library( + name = "cgroupfs", + srcs = [ + "base.go", + "cgroupfs.go", + "cpu.go", + "cpuacct.go", + "cpuset.go", + "dir_refs.go", + "job.go", + "memory.go", + ], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/coverage", + "//pkg/log", + "//pkg/refs", + "//pkg/refsvfs2", + "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", + "//pkg/sentry/usage", + "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + ], +) diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go new file mode 100644 index 000000000..0f54888d8 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/base.go @@ -0,0 +1,261 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "bytes" + "fmt" + "sort" + "strconv" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// controllerCommon implements kernel.CgroupController. +// +// Must call init before use. +// +// +stateify savable +type controllerCommon struct { + ty kernel.CgroupControllerType + fs *filesystem +} + +func (c *controllerCommon) init(ty kernel.CgroupControllerType, fs *filesystem) { + c.ty = ty + c.fs = fs +} + +// Type implements kernel.CgroupController.Type. 
+func (c *controllerCommon) Type() kernel.CgroupControllerType { + return kernel.CgroupControllerType(c.ty) +} + +// HierarchyID implements kernel.CgroupController.HierarchyID. +func (c *controllerCommon) HierarchyID() uint32 { + return c.fs.hierarchyID +} + +// NumCgroups implements kernel.CgroupController.NumCgroups. +func (c *controllerCommon) NumCgroups() uint64 { + return atomic.LoadUint64(&c.fs.numCgroups) +} + +// Enabled implements kernel.CgroupController.Enabled. +// +// Controllers are currently always enabled. +func (c *controllerCommon) Enabled() bool { + return true +} + +// Filesystem implements kernel.CgroupController.Filesystem. +func (c *controllerCommon) Filesystem() *vfs.Filesystem { + return c.fs.VFSFilesystem() +} + +// RootCgroup implements kernel.CgroupController.RootCgroup. +func (c *controllerCommon) RootCgroup() kernel.Cgroup { + return c.fs.rootCgroup() +} + +// controller is an interface for common functionality related to all cgroups. +// It is an extension of the public cgroup interface, containing cgroup +// functionality private to cgroupfs. +type controller interface { + kernel.CgroupController + + // AddControlFiles should extend the contents map with inodes representing + // control files defined by this controller. + AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode) +} + +// cgroupInode implements kernel.CgroupImpl and kernfs.Inode. +// +// +stateify savable +type cgroupInode struct { + dir + fs *filesystem + + // ts is the list of tasks in this cgroup. The kernel is responsible for + // removing tasks from this list before they're destroyed, so any tasks on + // this list are always valid. + // + // ts, and cgroup membership in general is protected by fs.tasksMu. + ts map[*kernel.Task]struct{} +} + +var _ kernel.CgroupImpl = (*cgroupInode)(nil) + +func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode { + c := &cgroupInode{ + fs: fs, + ts: make(map[*kernel.Task]struct{}), + } + + contents := make(map[string]kernfs.Inode) + contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c}) + contents["tasks"] = fs.newControllerFile(ctx, creds, &tasksData{c}) + + for _, ctl := range fs.controllers { + ctl.AddControlFiles(ctx, creds, c, contents) + } + + c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555)) + c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + c.dir.InitRefs() + c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents)) + + atomic.AddUint64(&fs.numCgroups, 1) + + return c +} + +func (c *cgroupInode) HierarchyID() uint32 { + return c.fs.hierarchyID +} + +// Controllers implements kernel.CgroupImpl.Controllers. +func (c *cgroupInode) Controllers() []kernel.CgroupController { + return c.fs.kcontrollers +} + +// Enter implements kernel.CgroupImpl.Enter. +func (c *cgroupInode) Enter(t *kernel.Task) { + c.fs.tasksMu.Lock() + c.ts[t] = struct{}{} + c.fs.tasksMu.Unlock() +} + +// Leave implements kernel.CgroupImpl.Leave. +func (c *cgroupInode) Leave(t *kernel.Task) { + c.fs.tasksMu.Lock() + delete(c.ts, t) + c.fs.tasksMu.Unlock() +} + +func sortTIDs(tids []kernel.ThreadID) { + sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] }) +} + +// +stateify savable +type cgroupProcsData struct { + *cgroupInode +} + +// Generate implements vfs.DynamicBytesSource.Generate. 
+func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + t := kernel.TaskFromContext(ctx) + currPidns := t.ThreadGroup().PIDNamespace() + + pgids := make(map[kernel.ThreadID]struct{}) + + d.fs.tasksMu.RLock() + defer d.fs.tasksMu.RUnlock() + + for task := range d.ts { + // Map dedups pgid, since iterating over all tasks produces multiple + // entries for the group leaders. + if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 { + pgids[pgid] = struct{}{} + } + } + + pgidList := make([]kernel.ThreadID, 0, len(pgids)) + for pgid, _ := range pgids { + pgidList = append(pgidList, pgid) + } + sortTIDs(pgidList) + + for _, pgid := range pgidList { + fmt.Fprintf(buf, "%d\n", pgid) + } + + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + // TODO(b/183137098): Payload is the pid for a process to add to this cgroup. + return src.NumBytes(), nil +} + +// +stateify savable +type tasksData struct { + *cgroupInode +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error { + t := kernel.TaskFromContext(ctx) + currPidns := t.ThreadGroup().PIDNamespace() + + var pids []kernel.ThreadID + + d.fs.tasksMu.RLock() + defer d.fs.tasksMu.RUnlock() + + for task := range d.ts { + if pid := currPidns.IDOfTask(task); pid != 0 { + pids = append(pids, pid) + } + } + sortTIDs(pids) + + for _, pid := range pids { + fmt.Fprintf(buf, "%d\n", pid) + } + + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + // TODO(b/183137098): Payload is the pid for a process to add to this cgroup. + return src.NumBytes(), nil +} + +// parseInt64FromString interprets src as string encoding a int64 value, and +// returns the parsed value. +func parseInt64FromString(ctx context.Context, src usermem.IOSequence, offset int64) (val, len int64, err error) { + const maxInt64StrLen = 20 // i.e. len(fmt.Sprintf("%d", math.MinInt64)) == 20 + + t := kernel.TaskFromContext(ctx) + src = src.DropFirst64(offset) + + buf := t.CopyScratchBuffer(maxInt64StrLen) + n, err := src.CopyIn(ctx, buf) + if err != nil { + return 0, int64(n), err + } + buf = buf[:n] + + val, err = strconv.ParseInt(string(buf), 10, 64) + if err != nil { + // Note: This also handles zero-len writes if offset is beyond the end + // of src, or src is empty. + ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err) + return 0, int64(n), syserror.EINVAL + } + + return val, int64(n), nil +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go new file mode 100644 index 000000000..bd3e69757 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -0,0 +1,425 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cgroupfs implements cgroupfs. +// +// A cgroup is a collection of tasks on the system, organized into a tree-like +// structure similar to a filesystem directory tree. In fact, each cgroup is +// represented by a directory on cgroupfs, and is manipulated through control +// files in the directory. +// +// All cgroups on a system are organized into hierarchies. Hierarchies are a +// distinct tree of cgroups, with a common set of controllers. One or more +// cgroupfs mounts may point to each hierarchy. These mounts provide a common +// view into the same tree of cgroups. +// +// A controller (also known as a "resource controller", or a cgroup "subsystem") +// determines the behaviour of each cgroup. +// +// In addition to cgroupfs, the kernel has a cgroup registry that tracks +// system-wide state related to cgroups such as active hierarchies and the +// controllers associated with them. +// +// Since cgroupfs doesn't allow hardlinks, there is a unique mapping between +// cgroupfs dentries and inodes. +// +// # Synchronization +// +// Cgroup hierarchy creation and destruction is protected by the +// kernel.CgroupRegistry.mu. Once created, a hierarchy's set of controllers, the +// filesystem associated with it, and the root cgroup for the hierarchy are +// immutable. +// +// Membership of tasks within cgroups is protected by +// cgroupfs.filesystem.tasksMu. Tasks also maintain a set of all cgroups they're +// in, and this list is protected by Task.mu. +// +// Lock order: +// +// kernel.CgroupRegistry.mu +// cgroupfs.filesystem.mu +// Task.mu +// cgroupfs.filesystem.tasksMu. +package cgroupfs + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" +) + +const ( + // Name is the default filesystem name. + Name = "cgroup" + readonlyFileMode = linux.FileMode(0444) + writableFileMode = linux.FileMode(0644) + defaultMaxCachedDentries = uint64(1000) +) + +const ( + controllerCPU = kernel.CgroupControllerType("cpu") + controllerCPUAcct = kernel.CgroupControllerType("cpuacct") + controllerCPUSet = kernel.CgroupControllerType("cpuset") + controllerJob = kernel.CgroupControllerType("job") + controllerMemory = kernel.CgroupControllerType("memory") +) + +var allControllers = []kernel.CgroupControllerType{ + controllerCPU, + controllerCPUAcct, + controllerCPUSet, + controllerJob, + controllerMemory, +} + +// SupportedMountOptions is the set of supported mount options for cgroupfs. +var SupportedMountOptions = []string{"all", "cpu", "cpuacct", "cpuset", "job", "memory"} + +// FilesystemType implements vfs.FilesystemType. +// +// +stateify savable +type FilesystemType struct{} + +// InternalData contains internal data passed in to the cgroupfs mount via +// vfs.GetFilesystemOptions.InternalData. +// +// +stateify savable +type InternalData struct { + DefaultControlValues map[string]int64 +} + +// filesystem implements vfs.FilesystemImpl. +// +// +stateify savable +type filesystem struct { + kernfs.Filesystem + devMinor uint32 + + // hierarchyID is the id the cgroup registry assigns to this hierarchy. Has + // the value kernel.InvalidCgroupHierarchyID until the FS is fully + // initialized. 
+ // + // hierarchyID is immutable after initialization. + hierarchyID uint32 + + // controllers and kcontrollers are both the list of controllers attached to + // this cgroupfs. Both lists are the same set of controllers, but typecast + // to different interfaces for convenience. Both must stay in sync, and are + // immutable. + controllers []controller + kcontrollers []kernel.CgroupController + + numCgroups uint64 // Protected by atomic ops. + + root *kernfs.Dentry + + // tasksMu serializes task membership changes across all cgroups within a + // filesystem. + tasksMu sync.RWMutex `state:"nosave"` +} + +// Name implements vfs.FilesystemType.Name. +func (FilesystemType) Name() string { + return Name +} + +// Release implements vfs.FilesystemType.Release. +func (FilesystemType) Release(ctx context.Context) {} + +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + + var wantControllers []kernel.CgroupControllerType + if _, ok := mopts["cpu"]; ok { + delete(mopts, "cpu") + wantControllers = append(wantControllers, controllerCPU) + } + if _, ok := mopts["cpuacct"]; ok { + delete(mopts, "cpuacct") + wantControllers = append(wantControllers, controllerCPUAcct) + } + if _, ok := mopts["cpuset"]; ok { + delete(mopts, "cpuset") + wantControllers = append(wantControllers, controllerCPUSet) + } + if _, ok := mopts["job"]; ok { + delete(mopts, "job") + wantControllers = append(wantControllers, controllerJob) + } + if _, ok := mopts["memory"]; ok { + delete(mopts, "memory") + wantControllers = append(wantControllers, controllerMemory) + } + if _, ok := mopts["all"]; ok { + if len(wantControllers) > 0 { + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers) + return nil, nil, syserror.EINVAL + } + + delete(mopts, "all") + wantControllers = allControllers + } + + if len(wantControllers) == 0 { + // Specifying no controllers implies all controllers. + wantControllers = allControllers + } + + if len(mopts) != 0 { + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts) + return nil, nil, syserror.EINVAL + } + + k := kernel.KernelFromContext(ctx) + r := k.CgroupRegistry() + + // "It is not possible to mount the same controller against multiple + // cgroup hierarchies. For example, it is not possible to mount both + // the cpu and cpuacct controllers against one hierarchy, and to mount + // the cpu controller alone against another hierarchy." - man cgroups(7) + // + // Is there a hierarchy available with all the controllers we want? If so, + // this mount is a view into the same hierarchy. + // + // Note: we're guaranteed to have at least one requested controller, since + // no explicit controller name implies all controllers. 
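	// Illustrative invocations (paths and option strings are examples only):
	//
	//   mount -t cgroup -o cpu,cpuacct none /sys/fs/cgroup/cpu,cpuacct
	//       wantControllers = [cpu, cpuacct]
	//   mount -t cgroup -o all none /sys/fs/cgroup   (or no -o at all)
	//       wantControllers = allControllers
	//
	// A second mount requesting the same controller set does not create a new
	// hierarchy; FindHierarchy below returns the existing filesystem and the
	// mount becomes another view of it.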
+ if vfsfs := r.FindHierarchy(wantControllers); vfsfs != nil { + fs := vfsfs.Impl().(*filesystem) + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID) + fs.root.IncRef() + return vfsfs, fs.root.VFSDentry(), nil + } + + // No existing hierarchy with the exactly controllers found. Make a new + // one. Note that it's possible this mount creation is unsatisfiable, if one + // or more of the requested controllers are already on existing + // hierarchies. We'll find out about such collisions when we try to register + // the new hierarchy later. + fs := &filesystem{ + devMinor: devMinor, + } + fs.MaxCachedDentries = maxCachedDentries + fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + + var defaults map[string]int64 + if opts.InternalData != nil { + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults) + defaults = opts.InternalData.(*InternalData).DefaultControlValues + } + + for _, ty := range wantControllers { + var c controller + switch ty { + case controllerCPU: + c = newCPUController(fs, defaults) + case controllerCPUAcct: + c = newCPUAcctController(fs) + case controllerCPUSet: + c = newCPUSetController(fs) + case controllerJob: + c = newJobController(fs) + case controllerMemory: + c = newMemoryController(fs, defaults) + default: + panic(fmt.Sprintf("Unreachable: unknown cgroup controller %q", ty)) + } + fs.controllers = append(fs.controllers, c) + } + + if len(defaults) != 0 { + // Internal data is always provided at sentry startup and unused values + // indicate a problem with the sandbox config. Fail fast. + panic(fmt.Sprintf("cgroupfs.FilesystemType.GetFilesystem: unknown internal mount data: %v", defaults)) + } + + // Controllers usually appear in alphabetical order when displayed. Sort it + // here now, so it never needs to be sorted elsewhere. + sort.Slice(fs.controllers, func(i, j int) bool { return fs.controllers[i].Type() < fs.controllers[j].Type() }) + fs.kcontrollers = make([]kernel.CgroupController, 0, len(fs.controllers)) + for _, c := range fs.controllers { + fs.kcontrollers = append(fs.kcontrollers, c) + } + + root := fs.newCgroupInode(ctx, creds) + var rootD kernfs.Dentry + rootD.InitRoot(&fs.Filesystem, root) + fs.root = &rootD + + // Register controllers. The registry may be modified concurrently, so if we + // get an error, we raced with someone else who registered the same + // controllers first. + hid, err := r.Register(fs.kcontrollers) + if err != nil { + ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err) + rootD.DecRef(ctx) + fs.VFSFilesystem().DecRef(ctx) + return nil, nil, syserror.EBUSY + } + fs.hierarchyID = hid + + // Move all existing tasks to the root of the new hierarchy. + k.PopulateNewCgroupHierarchy(fs.rootCgroup()) + + return fs.VFSFilesystem(), rootD.VFSDentry(), nil +} + +func (fs *filesystem) rootCgroup() kernel.Cgroup { + return kernel.Cgroup{ + Dentry: fs.root, + CgroupImpl: fs.root.Inode().(kernel.CgroupImpl), + } +} + +// Release implements vfs.FilesystemImpl.Release. 
+func (fs *filesystem) Release(ctx context.Context) { + k := kernel.KernelFromContext(ctx) + r := k.CgroupRegistry() + + if fs.hierarchyID != kernel.InvalidCgroupHierarchyID { + k.ReleaseCgroupHierarchy(fs.hierarchyID) + r.Unregister(fs.hierarchyID) + } + + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release(ctx) +} + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + var cnames []string + for _, c := range fs.controllers { + cnames = append(cnames, string(c.Type())) + } + return strings.Join(cnames, ",") +} + +// +stateify savable +type implStatFS struct{} + +// StatFS implements kernfs.Inode.StatFS. +func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil +} + +// dir implements kernfs.Inode for a generic cgroup resource controller +// directory. Specific controllers extend this to add their own functionality. +// +// +stateify savable +type dir struct { + dirRefs + kernfs.InodeAlwaysValid + kernfs.InodeAttrs + kernfs.InodeNotSymlink + kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir. + kernfs.OrderedChildren + implStatFS + + locks vfs.FileLocks +} + +// Keep implements kernfs.Inode.Keep. +func (*dir) Keep() bool { + return true +} + +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. +func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// Open implements kernfs.Inode.Open. +func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) + if err != nil { + return nil, err + } + return fd.VFSFileDescription(), nil +} + +// DecRef implements kernfs.Inode.DecRef. +func (d *dir) DecRef(ctx context.Context) { + d.dirRefs.DecRef(func() { d.Destroy(ctx) }) +} + +// StatFS implements kernfs.Inode.StatFS. +func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil +} + +// controllerFile represents a generic control file that appears within a cgroup +// directory. +// +// +stateify savable +type controllerFile struct { + kernfs.DynamicBytesFile +} + +func (fs *filesystem) newControllerFile(ctx context.Context, creds *auth.Credentials, data vfs.DynamicBytesSource) kernfs.Inode { + f := &controllerFile{} + f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, readonlyFileMode) + return f +} + +func (fs *filesystem) newControllerWritableFile(ctx context.Context, creds *auth.Credentials, data vfs.WritableDynamicBytesSource) kernfs.Inode { + f := &controllerFile{} + f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, writableFileMode) + return f +} + +// staticControllerFile represents a generic control file that appears within a +// cgroup directory which always returns the same data when read. +// staticControllerFiles are not writable. +// +// +stateify savable +type staticControllerFile struct { + kernfs.DynamicBytesFile + vfs.StaticData +} + +// Note: We let the caller provide the mode so that static files may be used to +// fake both readable and writable control files. 
However, static files are +// effectively readonly, as attempting to write to them will return EIO +// regardless of the mode. +func (fs *filesystem) newStaticControllerFile(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, data string) kernfs.Inode { + f := &staticControllerFile{StaticData: vfs.StaticData{Data: data}} + f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), f, mode) + return f +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cpu.go b/pkg/sentry/fsimpl/cgroupfs/cpu.go new file mode 100644 index 000000000..24d86a277 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cpu.go @@ -0,0 +1,70 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// +stateify savable +type cpuController struct { + controllerCommon + + // CFS bandwidth control parameters, values in microseconds. + cfsPeriod int64 + cfsQuota int64 + + // CPU shares, values should be (num core * 1024). + shares int64 +} + +var _ controller = (*cpuController)(nil) + +func newCPUController(fs *filesystem, defaults map[string]int64) *cpuController { + // Default values for controller parameters from Linux. + c := &cpuController{ + cfsPeriod: 100000, + cfsQuota: -1, + shares: 1024, + } + + if val, ok := defaults["cpu.cfs_period_us"]; ok { + c.cfsPeriod = val + delete(defaults, "cpu.cfs_period_us") + } + if val, ok := defaults["cpu.cfs_quota_us"]; ok { + c.cfsQuota = val + delete(defaults, "cpu.cfs_quota_us") + } + if val, ok := defaults["cpu.shares"]; ok { + c.shares = val + delete(defaults, "cpu.shares") + } + + c.controllerCommon.init(controllerCPU, fs) + return c +} + +// AddControlFiles implements controller.AddControlFiles. +func (c *cpuController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { + contents["cpu.cfs_period_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsPeriod)) + contents["cpu.cfs_quota_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsQuota)) + contents["cpu.shares"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.shares)) +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go new file mode 100644 index 000000000..d4104a00e --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go @@ -0,0 +1,114 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usage" +) + +// +stateify savable +type cpuacctController struct { + controllerCommon +} + +var _ controller = (*cpuacctController)(nil) + +func newCPUAcctController(fs *filesystem) *cpuacctController { + c := &cpuacctController{} + c.controllerCommon.init(controllerCPUAcct, fs) + return c +} + +// AddControlFiles implements controller.AddControlFiles. +func (c *cpuacctController) AddControlFiles(ctx context.Context, creds *auth.Credentials, cg *cgroupInode, contents map[string]kernfs.Inode) { + cpuacctCG := &cpuacctCgroup{cg} + contents["cpuacct.stat"] = c.fs.newControllerFile(ctx, creds, &cpuacctStatData{cpuacctCG}) + contents["cpuacct.usage"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageData{cpuacctCG}) + contents["cpuacct.usage_user"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageUserData{cpuacctCG}) + contents["cpuacct.usage_sys"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageSysData{cpuacctCG}) +} + +// +stateify savable +type cpuacctCgroup struct { + *cgroupInode +} + +func (c *cpuacctCgroup) collectCPUStats() usage.CPUStats { + var cs usage.CPUStats + c.fs.tasksMu.RLock() + // Note: This isn't very accurate, since the tasks are potentially + // still running as we accumulate their stats. + for t := range c.ts { + cs.Accumulate(t.CPUStats()) + } + c.fs.tasksMu.RUnlock() + return cs +} + +// +stateify savable +type cpuacctStatData struct { + *cpuacctCgroup +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *cpuacctStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { + cs := d.collectCPUStats() + fmt.Fprintf(buf, "user %d\n", linux.ClockTFromDuration(cs.UserTime)) + fmt.Fprintf(buf, "system %d\n", linux.ClockTFromDuration(cs.SysTime)) + return nil +} + +// +stateify savable +type cpuacctUsageData struct { + *cpuacctCgroup +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *cpuacctUsageData) Generate(ctx context.Context, buf *bytes.Buffer) error { + cs := d.collectCPUStats() + fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()+cs.SysTime.Nanoseconds()) + return nil +} + +// +stateify savable +type cpuacctUsageUserData struct { + *cpuacctCgroup +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *cpuacctUsageUserData) Generate(ctx context.Context, buf *bytes.Buffer) error { + cs := d.collectCPUStats() + fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()) + return nil +} + +// +stateify savable +type cpuacctUsageSysData struct { + *cpuacctCgroup +} + +// Generate implements vfs.DynamicBytesSource.Generate. 
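cpuacct.stat reports user and system time in clock ticks, while the usage files report nanoseconds. A rough sketch of that accumulation and conversion, assuming the conventional USER_HZ of 100; the real code uses linux.ClockTFromDuration and per-task t.CPUStats().

package main

import (
    "bytes"
    "fmt"
    "time"
)

// userHZ is assumed to be 100, the value Linux exposes to userspace.
const userHZ = 100

type cpuStats struct {
    user, sys time.Duration
}

func (c *cpuStats) accumulate(other cpuStats) {
    c.user += other.user
    c.sys += other.sys
}

// clockT converts a duration to USER_HZ clock ticks.
func clockT(d time.Duration) int64 {
    return int64(d / (time.Second / userHZ))
}

func main() {
    var total cpuStats
    // Stand-ins for per-task stats; in cgroupfs these come from each task.
    for _, ts := range []cpuStats{
        {user: 1500 * time.Millisecond, sys: 250 * time.Millisecond},
        {user: 40 * time.Millisecond, sys: 10 * time.Millisecond},
    } {
        total.accumulate(ts)
    }
    var buf bytes.Buffer
    fmt.Fprintf(&buf, "user %d\n", clockT(total.user))
    fmt.Fprintf(&buf, "system %d\n", clockT(total.sys))
    fmt.Fprintf(&buf, "usage %d\n", total.user.Nanoseconds()+total.sys.Nanoseconds())
    fmt.Print(buf.String())
}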
+func (d *cpuacctUsageSysData) Generate(ctx context.Context, buf *bytes.Buffer) error { + cs := d.collectCPUStats() + fmt.Fprintf(buf, "%d\n", cs.SysTime.Nanoseconds()) + return nil +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go new file mode 100644 index 000000000..ac547f8e2 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go @@ -0,0 +1,39 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// +stateify savable +type cpusetController struct { + controllerCommon +} + +var _ controller = (*cpusetController)(nil) + +func newCPUSetController(fs *filesystem) *cpusetController { + c := &cpusetController{} + c.controllerCommon.init(controllerCPUSet, fs) + return c +} + +// AddControlFiles implements controller.AddControlFiles. +func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { + // This controller is currently intentionally empty. +} diff --git a/pkg/sentry/fsimpl/cgroupfs/job.go b/pkg/sentry/fsimpl/cgroupfs/job.go new file mode 100644 index 000000000..48919c338 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/job.go @@ -0,0 +1,64 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/usermem" +) + +// +stateify savable +type jobController struct { + controllerCommon + id int64 +} + +var _ controller = (*jobController)(nil) + +func newJobController(fs *filesystem) *jobController { + c := &jobController{} + c.controllerCommon.init(controllerJob, fs) + return c +} + +func (c *jobController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { + contents["job.id"] = c.fs.newControllerWritableFile(ctx, creds, &jobIDData{c: c}) +} + +// +stateify savable +type jobIDData struct { + c *jobController +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *jobIDData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%d\n", d.c.id) + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. 
+func (d *jobIDData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + val, n, err := parseInt64FromString(ctx, src, offset) + if err != nil { + return n, err + } + d.c.id = val + return n, nil +} diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go new file mode 100644 index 000000000..485c98376 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/memory.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "bytes" + "fmt" + "math" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usage" +) + +// +stateify savable +type memoryController struct { + controllerCommon + + limitBytes int64 +} + +var _ controller = (*memoryController)(nil) + +func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController { + c := &memoryController{ + // Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which + // is ~ 2**63 on a 64-bit system. So essentially, inifinity. The exact + // value isn't very important. + limitBytes: math.MaxInt64, + } + if val, ok := defaults["memory.limit_in_bytes"]; ok { + c.limitBytes = val + delete(defaults, "memory.limit_in_bytes") + } + c.controllerCommon.init(controllerMemory, fs) + return c +} + +// AddControlFiles implements controller.AddControlFiles. +func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { + contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{}) + contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes)) +} + +// +stateify savable +type memoryUsageInBytesData struct{} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *memoryUsageInBytesData) Generate(ctx context.Context, buf *bytes.Buffer) error { + // TODO(b/183151557): This is a giant hack, we're using system-wide + // accounting since we know there is only one cgroup. 
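job.id is a read-write control file: reads render the value as "%d\n" and writes parse an integer back out of the user's buffer. A standalone sketch of that round trip, with strconv and a plain byte slice standing in for parseInt64FromString and the usermem.IOSequence plumbing.

package main

import (
    "bytes"
    "fmt"
    "strconv"
    "strings"
)

// int64File mimics a control file backed by a single integer value.
type int64File struct {
    val int64
}

// generate renders the value the way cgroupfs control files do.
func (f *int64File) generate(buf *bytes.Buffer) {
    fmt.Fprintf(buf, "%d\n", f.val)
}

// write parses an integer from user-provided text, tolerating trailing
// whitespace and newlines, and returns how many bytes were consumed.
func (f *int64File) write(src []byte) (int, error) {
    s := strings.TrimSpace(string(src))
    v, err := strconv.ParseInt(s, 10, 64)
    if err != nil {
        return 0, err
    }
    f.val = v
    return len(src), nil
}

func main() {
    f := &int64File{}
    if _, err := f.write([]byte("42\n")); err != nil {
        panic(err)
    }
    var buf bytes.Buffer
    f.generate(&buf)
    fmt.Print(buf.String()) // 42
}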
+ k := kernel.KernelFromContext(ctx) + mf := k.MemoryFile() + mf.UpdateUsage() + _, totalBytes := usage.MemoryAccounting.Copy() + + fmt.Fprintf(buf, "%d\n", totalBytes) + return nil +} diff --git a/pkg/sentry/fsimpl/eventfd/BUILD b/pkg/sentry/fsimpl/eventfd/BUILD index bcb01bb08..c09fdc7f9 100644 --- a/pkg/sentry/fsimpl/eventfd/BUILD +++ b/pkg/sentry/fsimpl/eventfd/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fdnotifier", + "//pkg/hostarch", "//pkg/log", "//pkg/sentry/vfs", "//pkg/syserror", diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 30bd05357..4f79cfcb7 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -188,7 +189,7 @@ func (efd *EventFileDescription) read(ctx context.Context, dst usermem.IOSequenc efd.queue.Notify(waiter.WritableEvents) var buf [8]byte - usermem.ByteOrder.PutUint64(buf[:], val) + hostarch.ByteOrder.PutUint64(buf[:], val) _, err := dst.CopyOut(ctx, buf[:]) return err } @@ -196,7 +197,7 @@ func (efd *EventFileDescription) read(ctx context.Context, dst usermem.IOSequenc // Preconditions: Must be called with efd.mu locked. func (efd *EventFileDescription) hostWriteLocked(val uint64) error { var buf [8]byte - usermem.ByteOrder.PutUint64(buf[:], val) + hostarch.ByteOrder.PutUint64(buf[:], val) _, err := unix.Write(efd.hostfd, buf[:]) if err == unix.EWOULDBLOCK { return syserror.ErrWouldBlock @@ -209,7 +210,7 @@ func (efd *EventFileDescription) write(ctx context.Context, src usermem.IOSequen if _, err := src.CopyIn(ctx, buf[:]); err != nil { return err } - val := usermem.ByteOrder.Uint64(buf[:]) + val := hostarch.ByteOrder.Uint64(buf[:]) return efd.Signal(val) } diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 155c0f56d..3a4777fbe 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -46,6 +46,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/marshal", "//pkg/refs", @@ -75,6 +76,7 @@ go_test( library = ":fuse", deps = [ "//pkg/abi/linux", + "//pkg/hostarch", "//pkg/marshal", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 23ce91849..66ea889f9 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -20,11 +20,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // ReadInPages sends FUSE_READ requests for the size after round it up to @@ -43,10 +43,10 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui } // Round up to a multiple of page size. - readSize, _ := usermem.PageRoundUp(uint64(size)) + readSize, _ := hostarch.PageRoundUp(uint64(size)) // One request cannnot exceed either maxRead or maxPages. 
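The read path rounds sizes to whole pages with hostarch.PageRoundUp/PageRoundDown. Those helpers are plain power-of-two arithmetic; a sketch assuming a 4KiB page (hostarch uses the actual host page size).

package main

import "fmt"

const (
    pageSize = 4096 // assumed; the real constant is the host page size
    pageMask = pageSize - 1
)

func pageRoundDown(x uint64) uint64 {
    return x &^ uint64(pageMask)
}

// pageRoundUp rounds x up to the next page boundary, reporting ok=false if the
// addition overflows, which is why callers check the second return value.
func pageRoundUp(x uint64) (uint64, bool) {
    sum := x + pageMask
    if sum < x {
        return 0, false
    }
    return pageRoundDown(sum), true
}

func main() {
    fmt.Println(pageRoundDown(8191))     // 4096
    fmt.Println(pageRoundUp(4097))       // 8192 true
    fmt.Println(pageRoundUp(^uint64(0))) // 0 false
}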
- maxPages := fs.conn.maxRead >> usermem.PageShift + maxPages := fs.conn.maxRead >> hostarch.PageShift if maxPages > uint32(fs.conn.maxPages) { maxPages = uint32(fs.conn.maxPages) } @@ -54,9 +54,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui var outs [][]byte var sizeRead uint32 - // readSize is a multiple of usermem.PageSize. + // readSize is a multiple of hostarch.PageSize. // Always request bytes as a multiple of pages. - pagesRead, pagesToRead := uint32(0), uint32(readSize>>usermem.PageShift) + pagesRead, pagesToRead := uint32(0), uint32(readSize>>hostarch.PageShift) // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation. in := linux.FUSEReadIn{ @@ -76,8 +76,8 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui pagesCanRead = maxPages } - in.Offset = off + (uint64(pagesRead) << usermem.PageShift) - in.Size = pagesCanRead << usermem.PageShift + in.Offset = off + (uint64(pagesRead) << hostarch.PageShift) + in.Size = pagesCanRead << hostarch.PageShift // TODO(gvisor.dev/issue/3247): support async read. @@ -159,7 +159,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, } // One request cannnot exceed either maxWrite or maxPages. - maxWrite := uint32(fs.conn.maxPages) << usermem.PageShift + maxWrite := uint32(fs.conn.maxPages) << hostarch.PageShift if maxWrite > fs.conn.maxWrite { maxWrite = fs.conn.maxWrite } @@ -188,8 +188,8 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, // Limit the write size to one page. // Note that the bigWrites flag is obsolete, // latest libfuse always sets it on. - if !fs.conn.bigWrites && toWrite > usermem.PageSize { - toWrite = usermem.PageSize + if !fs.conn.bigWrites && toWrite > hostarch.PageSize { + toWrite = hostarch.PageSize } // Limit the write size to maxWrite. diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index 10fb9d7d2..8a72489fa 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -19,10 +19,10 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/usermem" ) // fuseInitRes is a variable-length wrapper of linux.FUSEInitOut. The FUSE @@ -45,29 +45,29 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out := &r.initOut // Introduced before FUSE kernel version 7.13. 
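The loop above issues page-aligned FUSE_READ requests capped at maxPages pages per request. A toy version of that chunking, again assuming 4KiB pages and ignoring the FUSE transport itself.

package main

import "fmt"

const pageShift = 12 // assumed 4KiB pages

// pagedRequests returns the (offset, size) of each request needed to read
// `size` bytes starting at `off`, rounding the total up to whole pages and
// capping every request at maxPages pages, as the FUSE read path does.
func pagedRequests(off uint64, size, maxPages uint32) [][2]uint64 {
    pagesToRead := (size + (1 << pageShift) - 1) >> pageShift
    var reqs [][2]uint64
    for pagesRead := uint32(0); pagesRead < pagesToRead; {
        pagesCanRead := pagesToRead - pagesRead
        if pagesCanRead > maxPages {
            pagesCanRead = maxPages
        }
        reqs = append(reqs, [2]uint64{
            off + (uint64(pagesRead) << pageShift),
            uint64(pagesCanRead) << pageShift,
        })
        pagesRead += pagesCanRead
    }
    return reqs
}

func main() {
    // Read 9000 bytes at offset 4096 with at most 2 pages per request:
    // expect one 8192-byte request followed by one 4096-byte request.
    for _, r := range pagedRequests(4096, 9000, 2) {
        fmt.Printf("offset=%d size=%d\n", r[0], r[1])
    }
}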
- out.Major = uint32(usermem.ByteOrder.Uint32(src[:4])) + out.Major = uint32(hostarch.ByteOrder.Uint32(src[:4])) src = src[4:] - out.Minor = uint32(usermem.ByteOrder.Uint32(src[:4])) + out.Minor = uint32(hostarch.ByteOrder.Uint32(src[:4])) src = src[4:] - out.MaxReadahead = uint32(usermem.ByteOrder.Uint32(src[:4])) + out.MaxReadahead = uint32(hostarch.ByteOrder.Uint32(src[:4])) src = src[4:] - out.Flags = uint32(usermem.ByteOrder.Uint32(src[:4])) + out.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) src = src[4:] - out.MaxBackground = uint16(usermem.ByteOrder.Uint16(src[:2])) + out.MaxBackground = uint16(hostarch.ByteOrder.Uint16(src[:2])) src = src[2:] - out.CongestionThreshold = uint16(usermem.ByteOrder.Uint16(src[:2])) + out.CongestionThreshold = uint16(hostarch.ByteOrder.Uint16(src[:2])) src = src[2:] - out.MaxWrite = uint32(usermem.ByteOrder.Uint32(src[:4])) + out.MaxWrite = uint32(hostarch.ByteOrder.Uint32(src[:4])) src = src[4:] // Introduced in FUSE kernel version 7.23. if len(src) >= 4 { - out.TimeGran = uint32(usermem.ByteOrder.Uint32(src[:4])) + out.TimeGran = uint32(hostarch.ByteOrder.Uint32(src[:4])) src = src[4:] } // Introduced in FUSE kernel version 7.28. if len(src) >= 2 { - out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) + out.MaxPages = uint16(hostarch.ByteOrder.Uint16(src[:2])) src = src[2:] } _ = src // Remove unused warning. diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go index 2c0cc0f4e..b0bab0066 100644 --- a/pkg/sentry/fsimpl/fuse/utils_test.go +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -24,7 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) func setup(t *testing.T) *testutil.System { @@ -82,12 +83,12 @@ func (t *testPayload) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (t *testPayload) MarshalBytes(dst []byte) { - usermem.ByteOrder.PutUint32(dst[:4], t.data) + hostarch.ByteOrder.PutUint32(dst[:4], t.data) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. func (t *testPayload) UnmarshalBytes(src []byte) { - *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])} + *t = testPayload{data: hostarch.ByteOrder.Uint32(src[:4])} } // Packed implements marshal.Marshallable.Packed. @@ -106,17 +107,17 @@ func (t *testPayload) UnmarshalUnsafe(src []byte) { } // CopyOutN implements marshal.Marshallable.CopyOutN. -func (t *testPayload) CopyOutN(task marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { +func (t *testPayload) CopyOutN(task marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { panic("not implemented") } // CopyOut implements marshal.Marshallable.CopyOut. -func (t *testPayload) CopyOut(task marshal.CopyContext, addr usermem.Addr) (int, error) { +func (t *testPayload) CopyOut(task marshal.CopyContext, addr hostarch.Addr) (int, error) { panic("not implemented") } // CopyIn implements marshal.Marshallable.CopyIn. 
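fuseInitRes decodes a response whose tail grew across FUSE protocol revisions, so the trailing fields are read only if enough bytes remain. A compact sketch of the same pattern with encoding/binary; little-endian is assumed here for simplicity, whereas the sentry uses hostarch.ByteOrder (the host's native order), and the field set is trimmed down.

package main

import (
    "encoding/binary"
    "fmt"
)

type initOut struct {
    Major, Minor uint32
    MaxWrite     uint32
    TimeGran     uint32 // only present in newer protocol revisions
    MaxPages     uint16 // only present in even newer revisions
}

func unmarshalInitOut(src []byte) initOut {
    var out initOut
    out.Major = binary.LittleEndian.Uint32(src[:4])
    src = src[4:]
    out.Minor = binary.LittleEndian.Uint32(src[:4])
    src = src[4:]
    out.MaxWrite = binary.LittleEndian.Uint32(src[:4])
    src = src[4:]
    // Optional tail: older servers simply send fewer bytes.
    if len(src) >= 4 {
        out.TimeGran = binary.LittleEndian.Uint32(src[:4])
        src = src[4:]
    }
    if len(src) >= 2 {
        out.MaxPages = binary.LittleEndian.Uint16(src[:2])
    }
    return out
}

func main() {
    old := make([]byte, 12) // older layout: no TimeGran/MaxPages bytes at all
    binary.LittleEndian.PutUint32(old[0:], 7)
    binary.LittleEndian.PutUint32(old[4:], 13)
    binary.LittleEndian.PutUint32(old[8:], 65536)
    fmt.Printf("%+v\n", unmarshalInitOut(old))
}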
-func (t *testPayload) CopyIn(task marshal.CopyContext, addr usermem.Addr) (int, error) { +func (t *testPayload) CopyIn(task marshal.CopyContext, addr hostarch.Addr) (int, error) { panic("not implemented") } diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 807b6ed1f..6d5258a9b 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/fd", "//pkg/fdnotifier", "//pkg/fspath", + "//pkg/hostarch", "//pkg/log", "//pkg/p9", "//pkg/refs", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 9da01cba3..177e42649 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -28,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) func (d *dentry) isDir() bool { @@ -98,7 +98,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), - blockSize: usermem.PageSize, // arbitrary + blockSize: hostarch.PageSize, // arbitrary atime: now, mtime: now, ctime: now, diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 43c3c5a2d..4b5621043 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -141,21 +141,8 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp ** return } ds := **dsp - // Only go through calling dentry.checkCachingLocked() (which requires - // re-locking renameMu) if we actually have any dentries with zero refs. - checkAny := false - for i := range ds { - if atomic.LoadInt64(&ds[i].refs) == 0 { - checkAny = true - break - } - } - if checkAny { - fs.renameMu.Lock() - for _, d := range ds { - d.checkCachingLocked(ctx) - } - fs.renameMu.Unlock() + for _, d := range ds { + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } putDentrySlice(*dsp) } @@ -166,7 +153,7 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[] return } for _, d := range **ds { - d.checkCachingLocked(ctx) + d.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } fs.renameMu.Unlock() putDentrySlice(*ds) @@ -339,8 +326,10 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir } parent.cacheNewChildLocked(child, name) // For now, child has 0 references, so our caller should call - // child.checkCachingLocked(). + // child.checkCachingLocked(). parent gained a ref so we should also call + // parent.checkCachingLocked() so it can be removed from the cache if needed. *ds = appendDentry(*ds, child) + *ds = appendDentry(*ds, parent) return child, nil } @@ -723,6 +712,8 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op } } d.IncRef() + // Call d.checkCachingLocked() so it can be removed from the cache if needed. + ds = appendDentry(ds, d) return &d.vfsd, nil } @@ -744,6 +735,8 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa return nil, err } d.IncRef() + // Call d.checkCachingLocked() so it can be removed from the cache if needed. 
+ ds = appendDentry(ds, d) return &d.vfsd, nil } @@ -782,7 +775,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { creds := rp.Credentials() - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { + return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { // If the parent is a setgid directory, use the parent's GID // rather than the caller's and enable setgid. kgid := creds.EffectiveKGID @@ -802,6 +795,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kuid: creds.EffectiveKUID, kgid: creds.EffectiveKGID, }) + *ds = appendDentry(*ds, parent) } if fs.opts.interop != InteropModeShared { parent.incLinks() @@ -855,6 +849,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kgid: creds.EffectiveKGID, endpoint: opts.Endpoint, }) + *ds = appendDentry(*ds, parent) return nil case linux.S_IFIFO: parent.createSyntheticChildLocked(&createSyntheticOpts{ @@ -864,6 +859,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kgid: creds.EffectiveKGID, pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) + *ds = appendDentry(*ds, parent) return nil } // Retain error from gofer if synthetic file cannot be created internally. @@ -912,6 +908,8 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf start.IncRef() defer start.DecRef(ctx) unlock() + // start is intentionally not added to ds (which would remove it from the + // cache) because doing so regresses performance in practice. return start.open(ctx, rp, &opts) } @@ -965,6 +963,8 @@ afterTrailingSymlink: child.IncRef() defer child.DecRef(ctx) unlock() + // child is intentionally not added to ds (which would remove it from the + // cache) because doing so regresses performance in practice. return child.open(ctx, rp, &opts) } @@ -1212,6 +1212,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } // Insert the dentry into the tree. 
d.cacheNewChildLocked(child, name) + *ds = appendDentry(*ds, d) if d.cachedMetadataAuthoritative() { d.touchCMtime() d.dirents = nil @@ -1403,6 +1404,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa oldParent.decRefNoCaching() ds = appendDentry(ds, oldParent) newParent.IncRef() + ds = appendDentry(ds, newParent) if renamed.isSynthetic() { oldParent.syntheticChildren-- newParent.syntheticChildren++ @@ -1546,6 +1548,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath if d.isSocket() { if !d.isSynthetic() { d.IncRef() + ds = appendDentry(ds, d) return &endpoint{ dentry: d, path: opts.Addr, diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 692da02c1..fb42c5f62 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -18,15 +18,17 @@ // Lock order: // regularFileFD/directoryFD.mu // filesystem.renameMu -// dentry.dirMu -// filesystem.syncMu -// dentry.metadataMu -// *** "memmap.Mappable locks" below this point -// dentry.mapsMu -// *** "memmap.Mappable locks taken by Translate" below this point -// dentry.handleMu -// dentry.dataMu -// filesystem.inoMu +// dentry.cachingMu +// filesystem.cacheMu +// dentry.dirMu +// filesystem.syncMu +// dentry.metadataMu +// *** "memmap.Mappable locks" below this point +// dentry.mapsMu +// *** "memmap.Mappable locks taken by Translate" below this point +// dentry.handleMu +// dentry.dataMu +// filesystem.inoMu // specialFileFD.mu // specialFileFD.bufMu // @@ -44,6 +46,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" refs_vfs1 "gvisor.dev/gvisor/pkg/refs" @@ -60,7 +63,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" - "gvisor.dev/gvisor/pkg/usermem" ) // Name is the default filesystem name. @@ -140,7 +142,8 @@ type filesystem struct { // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) // cachedDentriesLen is the number of dentries in cachedDentries. These fields - // are protected by renameMu. + // are protected by cacheMu. + cacheMu sync.Mutex `state:"nosave"` cachedDentries dentryList cachedDentriesLen uint64 @@ -620,11 +623,11 @@ func (fs *filesystem) Release(ctx context.Context) { // the reference count on every synthetic dentry. Synthetic dentries have one // reference for existence that should be dropped during filesystem.Release. // -// Precondition: d.fs.renameMu is locked. +// Precondition: d.fs.renameMu is locked for writing. func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { if d.isSynthetic() { d.decRefNoCaching() - d.checkCachingLocked(ctx) + d.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } if d.isDir() { var children []*dentry @@ -682,9 +685,13 @@ type dentry struct { // deleted. deleted is accessed using atomic memory operations. deleted uint32 + // cachingMu is used to synchronize concurrent dentry caching attempts on + // this dentry. + cachingMu sync.Mutex `state:"nosave"` + // If cached is true, dentryEntry links dentry into // filesystem.cachedDentries. cached and dentryEntry are protected by - // filesystem.renameMu. + // cachingMu. 
cached bool dentryEntry @@ -872,7 +879,7 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma mode: uint32(attr.Mode), uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), - blockSize: usermem.PageSize, + blockSize: hostarch.PageSize, readFD: -1, writeFD: -1, mmapFD: -1, @@ -980,36 +987,63 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } // Preconditions: !d.isSynthetic(). +// Preconditions: d.metadataMu is locked. +func (d *dentry) refreshSizeLocked(ctx context.Context) error { + d.handleMu.RLock() + + if d.writeFD < 0 { + d.handleMu.RUnlock() + // Ask the gofer if we don't have a host FD. + return d.updateFromGetattrLocked(ctx) + } + + var stat unix.Statx_t + err := unix.Statx(int(d.writeFD), "", unix.AT_EMPTY_PATH, unix.STATX_SIZE, &stat) + d.handleMu.RUnlock() // must be released before updateSizeLocked() + if err != nil { + return err + } + d.updateSizeLocked(stat.Size) + return nil +} + +// Preconditions: !d.isSynthetic(). func (d *dentry) updateFromGetattr(ctx context.Context) error { - // Use d.readFile or d.writeFile, which represent 9P fids that have been + // d.metadataMu must be locked *before* we getAttr so that we do not end up + // updating stale attributes in d.updateFromP9AttrsLocked(). + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + return d.updateFromGetattrLocked(ctx) +} + +// Preconditions: +// * !d.isSynthetic(). +// * d.metadataMu is locked. +func (d *dentry) updateFromGetattrLocked(ctx context.Context) error { + // Use d.readFile or d.writeFile, which represent 9P FIDs that have been // opened, in preference to d.file, which represents a 9P fid that has not. // This may be significantly more efficient in some implementations. Prefer // d.writeFile over d.readFile since some filesystem implementations may // update a writable handle's metadata after writes to that handle, without // making metadata updates immediately visible to read-only handles // representing the same file. - var ( - file p9file - handleMuRLocked bool - ) - // d.metadataMu must be locked *before* we getAttr so that we do not end up - // updating stale attributes in d.updateFromP9AttrsLocked(). - d.metadataMu.Lock() - defer d.metadataMu.Unlock() d.handleMu.RLock() - if !d.writeFile.isNil() { + handleMuRLocked := true + var file p9file + switch { + case !d.writeFile.isNil(): file = d.writeFile - handleMuRLocked = true - } else if !d.readFile.isNil() { + case !d.readFile.isNil(): file = d.readFile - handleMuRLocked = true - } else { + default: file = d.file d.handleMu.RUnlock() + handleMuRLocked = false } + _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask()) if handleMuRLocked { - d.handleMu.RUnlock() + d.handleMu.RUnlock() // must be released before updateFromP9AttrsLocked() } if err != nil { return err @@ -1104,24 +1138,27 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs defer d.metadataMu.Unlock() // As with Linux, if the UID, GID, or file size is changing, we have to - // clear permission bits. Note that when set, clearSGID causes - // permissions to be updated, but does not modify stat.Mask, as - // modification would cause an extra inotify flag to be set. - clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) || - stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) || + // clear permission bits. Note that when set, clearSGID may cause + // permissions to be updated. 
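refreshSizeLocked asks the host for only the size of an already-open FD instead of doing a full gofer getattr. A self-contained sketch of that statx(2) call through golang.org/x/sys/unix, run against a throwaway temp file.

package main

import (
    "fmt"
    "os"

    "golang.org/x/sys/unix"
)

// sizeOfFD returns the size of an open file descriptor using statx(2) with an
// empty path, requesting only STATX_SIZE so the kernel can skip other fields.
func sizeOfFD(fd int) (uint64, error) {
    var stat unix.Statx_t
    if err := unix.Statx(fd, "", unix.AT_EMPTY_PATH, unix.STATX_SIZE, &stat); err != nil {
        return 0, err
    }
    return stat.Size, nil
}

func main() {
    f, err := os.CreateTemp("", "statx-example")
    if err != nil {
        panic(err)
    }
    defer os.Remove(f.Name())
    defer f.Close()
    if _, err := f.WriteString("hello"); err != nil {
        panic(err)
    }
    size, err := sizeOfFD(int(f.Fd()))
    fmt.Println(size, err) // 5 <nil>
}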
+ clearSGID := (stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid)) || + (stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid)) || stat.Mask&linux.STATX_SIZE != 0 if clearSGID { if stat.Mask&linux.STATX_MODE != 0 { stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode))) } else { - stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode))) + oldMode := atomic.LoadUint32(&d.mode) + if updatedMode := vfs.ClearSUIDAndSGID(oldMode); updatedMode != oldMode { + stat.Mode = uint16(updatedMode) + stat.Mask |= linux.STATX_MODE + } } } if !d.isSynthetic() { if stat.Mask != 0 { if err := d.file.setAttr(ctx, p9.SetAttrMask{ - Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID, + Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, GID: stat.Mask&linux.STATX_GID != 0, Size: stat.Mask&linux.STATX_SIZE != 0, @@ -1156,7 +1193,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } } - if stat.Mask&linux.STATX_MODE != 0 || clearSGID { + if stat.Mask&linux.STATX_MODE != 0 { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } if stat.Mask&linux.STATX_UID != 0 { @@ -1217,8 +1254,8 @@ func (d *dentry) updateSizeLocked(newSize uint64) { // so we can't race with Write or another truncate.) d.dataMu.Unlock() if d.size < oldSize { - oldpgend, _ := usermem.PageRoundUp(oldSize) - newpgend, _ := usermem.PageRoundUp(d.size) + oldpgend, _ := hostarch.PageRoundUp(oldSize) + newpgend, _ := hostarch.PageRoundUp(d.size) if oldpgend != newpgend { d.mapsMu.Lock() d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ @@ -1312,9 +1349,7 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { if d.decRefNoCaching() == 0 { - d.fs.renameMu.Lock() - d.checkCachingLocked(ctx) - d.fs.renameMu.Unlock() + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } } @@ -1374,15 +1409,16 @@ func (d *dentry) Watches() *vfs.Watches { // // If no watches are left on this dentry and it has no references, cache it. func (d *dentry) OnZeroWatches(ctx context.Context) { - if atomic.LoadInt64(&d.refs) == 0 { - d.fs.renameMu.Lock() - d.checkCachingLocked(ctx) - d.fs.renameMu.Unlock() - } + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } -// checkCachingLocked should be called after d's reference count becomes 0 or it -// becomes disowned. +// checkCachingLocked should be called after d's reference count becomes 0 or +// it becomes disowned. +// +// For performance, checkCachingLocked can also be called after d's reference +// count becomes non-zero, so that d can be removed from the LRU cache. This +// may help in reducing the size of the cache and hence reduce evictions. Note +// that this is not necessary for correctness. // // It may be called on a destroyed dentry. For example, // renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times @@ -1390,33 +1426,46 @@ func (d *dentry) OnZeroWatches(ctx context.Context) { // operation. One of the calls may destroy the dentry, so subsequent calls will // do nothing. // -// Preconditions: d.fs.renameMu must be locked for writing; it may be -// temporarily unlocked. -func (d *dentry) checkCachingLocked(ctx context.Context) { - // Dentries with a non-zero reference count must be retained. 
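The clearSGID handling above follows the usual Linux setattr rule: changing a file's owner, group, or size drops setuid, and drops setgid only when the group-execute bit is set (setgid without group-exec marks mandatory locking and is preserved). A sketch of that rule, on the assumption that vfs.ClearSUIDAndSGID implements the same semantics.

package main

import "fmt"

const (
    setuid    = 0o4000
    setgid    = 0o2000
    groupExec = 0o0010
)

// clearSUIDAndSGID models the privilege-dropping rule applied when a file's
// UID, GID, or size changes: always clear setuid; clear setgid only if the
// group-execute bit is set.
func clearSUIDAndSGID(mode uint32) uint32 {
    mode &^= setuid
    if mode&groupExec != 0 {
        mode &^= setgid
    }
    return mode
}

func main() {
    fmt.Printf("%o\n", clearSUIDAndSGID(0o6755)) // 755: both bits dropped
    fmt.Printf("%o\n", clearSUIDAndSGID(0o2640)) // 2640: setgid kept, no group exec
}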
(The only way - // to obtain a reference on a dentry with zero references is via path - // resolution, which requires renameMu, so if d.refs is zero then it will - // remain zero while we hold renameMu for writing.) +// Preconditions: d.fs.renameMu must be locked for writing if +// renameMuWriteLocked is true; it may be temporarily unlocked. +func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked bool) { + d.cachingMu.Lock() refs := atomic.LoadInt64(&d.refs) if refs == -1 { // Dentry has already been destroyed. + d.cachingMu.Unlock() return } if refs > 0 { - // This isn't strictly necessary (fs.cachedDentries is permitted to - // contain dentries with non-zero refs, which are skipped by - // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but - // since we are already holding fs.renameMu for writing we may as well. + // fs.cachedDentries is permitted to contain dentries with non-zero refs, + // which are skipped by fs.evictCachedDentryLocked() upon reaching the end + // of the LRU. But it is still beneficial to remove d from the cache as we + // are already holding d.cachingMu. Keeping a cleaner cache also reduces + // the number of evictions (which is expensive as it acquires fs.renameMu). d.removeFromCacheLocked() + d.cachingMu.Unlock() return } // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { + d.removeFromCacheLocked() + d.cachingMu.Unlock() + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu for writing as needed by d.destroyLocked(). + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + // Now that renameMu is locked for writing, no more refs can be taken on + // d because path resolution requires renameMu for reading at least. + if atomic.LoadInt64(&d.refs) != 0 { + // Destroy d only if its ref is still 0. If not, either someone took a + // ref on it or it got destroyed before fs.renameMu could be acquired. + return + } + } if d.isDeleted() { d.watches.HandleDeletion(ctx) } - d.removeFromCacheLocked() d.destroyLocked(ctx) return } @@ -1426,24 +1475,36 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { // d.watches cannot concurrently transition from zero to non-zero, because // adding a watch requires holding a reference on d. if d.watches.Size() > 0 { - // As in the refs > 0 case, this is not strictly necessary. + // As in the refs > 0 case, removing d is beneficial. d.removeFromCacheLocked() + d.cachingMu.Unlock() return } if atomic.LoadInt32(&d.fs.released) != 0 { + d.cachingMu.Unlock() + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu to access d.parent. Lock it for writing as + // needed by d.destroyLocked() later. + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + } if d.parent != nil { d.parent.dirMu.Lock() delete(d.parent.children, d.name) d.parent.dirMu.Unlock() } d.destroyLocked(ctx) + return } + d.fs.cacheMu.Lock() // If d is already cached, just move it to the front of the LRU. 
if d.cached { d.fs.cachedDentries.Remove(d) d.fs.cachedDentries.PushFront(d) + d.fs.cacheMu.Unlock() + d.cachingMu.Unlock() return } // Cache the dentry, then evict the least recently used cached dentry if @@ -1451,18 +1512,28 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { d.fs.cachedDentries.PushFront(d) d.fs.cachedDentriesLen++ d.cached = true - if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { + shouldEvict := d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries + d.fs.cacheMu.Unlock() + d.cachingMu.Unlock() + + if shouldEvict { + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu for writing as needed by + // d.evictCachedDentryLocked(). + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + } d.fs.evictCachedDentryLocked(ctx) - // Whether or not victim was destroyed, we brought fs.cachedDentriesLen - // back down to fs.opts.maxCachedDentries, so we don't loop. } } -// Preconditions: d.fs.renameMu must be locked for writing. +// Preconditions: d.cachingMu must be locked. func (d *dentry) removeFromCacheLocked() { if d.cached { + d.fs.cacheMu.Lock() d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- + d.fs.cacheMu.Unlock() d.cached = false } } @@ -1477,28 +1548,43 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { // Preconditions: // * fs.renameMu must be locked for writing; it may be temporarily unlocked. -// * fs.cachedDentriesLen != 0. func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + fs.cacheMu.Lock() victim := fs.cachedDentries.Back() + fs.cacheMu.Unlock() + if victim == nil { + // fs.cachedDentries may have become empty between when it was checked and + // when we locked fs.cacheMu. + return + } + + victim.cachingMu.Lock() victim.removeFromCacheLocked() // victim.refs or victim.watches.Size() may have become non-zero from an // earlier path resolution since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 { - if victim.parent != nil { - victim.parent.dirMu.Lock() - if !victim.vfsd.IsDead() { - // Note that victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount points. - fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) - delete(victim.parent.children, victim.name) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victim.parent.dirMu.Unlock() + if atomic.LoadInt64(&victim.refs) != 0 || victim.watches.Size() != 0 { + victim.cachingMu.Unlock() + return + } + if victim.parent != nil { + victim.parent.dirMu.Lock() + if !victim.vfsd.IsDead() { + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + // We're only deleting the dentry, not the file it + // represents, so we don't need to update + // victimParent.dirents etc. } - victim.destroyLocked(ctx) + victim.parent.dirMu.Unlock() } + // Safe to unlock cachingMu now that victim.vfsd.IsDead(). Henceforth any + // concurrent caching attempts on victim will attempt to destroy it and so + // will try to acquire fs.renameMu (which we have already acquired). Hence, + // fs.renameMu will synchronize the destroy attempts. + victim.cachingMu.Unlock() + victim.destroyLocked(ctx) } // destroyLocked destroys the dentry. 
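The dentry cache is an LRU list under its own lock: a dentry whose reference count reaches zero is pushed to the front, and once the cache exceeds maxCachedDentries a victim is taken from the back, skipping anything that regained a reference in the meantime. A simplified sketch of that shape with container/list; reference counting and the renameMu interplay are reduced to a plain field here.

package main

import (
    "container/list"
    "fmt"
    "sync"
)

type node struct {
    name string
    refs int64
    elem *list.Element // non-nil while the node sits in the cache
}

type cache struct {
    mu  sync.Mutex
    lru *list.List // front = most recently idle
    max int
}

// noteIdle is called when a node's reference count drops to zero. It moves the
// node to the front of the LRU and evicts from the back if over capacity.
func (c *cache) noteIdle(n *node) (evicted *node) {
    c.mu.Lock()
    defer c.mu.Unlock()
    if n.elem != nil {
        c.lru.MoveToFront(n.elem)
    } else {
        n.elem = c.lru.PushFront(n)
    }
    if c.lru.Len() <= c.max {
        return nil
    }
    // Evict from the back, skipping nodes that regained a reference since they
    // were cached; those are simply dropped from the list, not destroyed.
    for e := c.lru.Back(); e != nil; e = c.lru.Back() {
        victim := e.Value.(*node)
        c.lru.Remove(e)
        victim.elem = nil
        if victim.refs == 0 {
            return victim
        }
    }
    return nil
}

func main() {
    c := &cache{lru: list.New(), max: 2}
    a, b, d := &node{name: "a"}, &node{name: "b"}, &node{name: "d"}
    c.noteIdle(a)
    c.noteIdle(b)
    if v := c.noteIdle(d); v != nil {
        fmt.Println("evicted", v.name) // a: the least recently idle
    }
}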
@@ -1584,7 +1670,7 @@ func (d *dentry) destroyLocked(ctx context.Context) { // Drop the reference held by d on its parent without recursively locking // d.fs.renameMu. if d.parent != nil && d.parent.decRefNoCaching() == 0 { - d.parent.checkCachingLocked(ctx) + d.parent.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } refsvfs2.Unregister(d) } diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index 76f08e252..806392d50 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -55,7 +55,7 @@ func TestDestroyIdempotent(t *testing.T) { fs.renameMu.Lock() defer fs.renameMu.Unlock() - child.checkCachingLocked(ctx) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) if got := atomic.LoadInt64(&child.refs); got != -1 { t.Fatalf("child.refs=%d, want: -1", got) } @@ -63,6 +63,6 @@ func TestDestroyIdempotent(t *testing.T) { if got := atomic.LoadInt64(&parent.refs); got != -1 { t.Fatalf("parent.refs=%d, want: -1", got) } - child.checkCachingLocked(ctx) - child.checkCachingLocked(ctx) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 4f1ad0c88..f0e7bbaf7 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" @@ -203,18 +204,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } d := fd.dentry() + + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + // If the fd was opened with O_APPEND, make sure the file size is updated. // There is a possible race here if size is modified externally after // metadata cache is updated. if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { - if err := d.updateFromGetattr(ctx); err != nil { + if err := d.refreshSizeLocked(ctx); err != nil { return 0, offset, err } } - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - // Set offset to file size if the fd was opened with O_APPEND. if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { // Holding d.metadataMu is sufficient for reading d.size. @@ -291,8 +293,8 @@ func (fd *regularFileFD) writeCache(ctx context.Context, d *dentry, offset int64 } // Remove touched pages from the cache. - pgstart := usermem.PageRoundDown(uint64(offset)) - pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes())) + pgstart := hostarch.PageRoundDown(uint64(offset)) + pgend, ok := hostarch.PageRoundUp(uint64(offset + src.NumBytes())) if !ok { return syserror.EINVAL } @@ -408,7 +410,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) switch { case seg.Ok(): // Get internal mappings from the cache. - ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) + ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Read) if err != nil { dataMuUnlock() rw.d.handleMu.RUnlock() @@ -434,9 +436,9 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) if fillCache { // Read into the cache, then re-enter the loop to read from the // cache. 
- gapEnd, _ := usermem.PageRoundUp(gapMR.End) + gapEnd, _ := hostarch.PageRoundUp(gapMR.End) reqMR := memmap.MappableRange{ - Start: usermem.PageRoundDown(gapMR.Start), + Start: hostarch.PageRoundDown(gapMR.Start), End: gapEnd, } optMR := gap.Range() @@ -527,7 +529,7 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro case seg.Ok(): // Get internal mappings from the cache. segMR := seg.Range().Intersect(mr) - ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write) + ims, err := mf.MapInternal(seg.FileRangeOf(segMR), hostarch.Write) if err != nil { retErr = err goto exitLoop @@ -700,6 +702,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt } // After this point, d may be used as a memmap.Mappable. d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init) + opts.SentryOwnedContent = d.fs.opts.forcePageCache return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts) } @@ -714,7 +717,7 @@ func (d *dentry) mayCachePages() bool { } // AddMapping implements memmap.Mappable.AddMapping. -func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { d.mapsMu.Lock() mapped := d.mappings.AddMapping(ms, ar, offset, writable) // Do this unconditionally since whether we have a host FD can change @@ -735,7 +738,7 @@ func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar user } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { d.mapsMu.Lock() unmapped := d.mappings.RemoveMapping(ms, ar, offset, writable) for _, r := range unmapped { @@ -759,12 +762,12 @@ func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar u } // CopyMapping implements memmap.Mappable.CopyMapping. -func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { return d.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. -func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { d.handleMu.RLock() if d.mmapFD >= 0 && !d.fs.opts.forcePageCache { d.handleMu.RUnlock() @@ -777,7 +780,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab Source: mr, File: &d.pf, Offset: mr.Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }, }, nil } @@ -786,7 +789,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab // Constrain translations to d.size (rounded up) to prevent translation to // pages that may be concurrently truncated. 
- pgend, _ := usermem.PageRoundUp(d.size) + pgend, _ := hostarch.PageRoundUp(d.size) var beyondEOF bool if required.End > pgend { if required.Start >= pgend { @@ -811,7 +814,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab segMR := seg.Range().Intersect(optional) // TODO(jamieliu): Make Translations writable even if writability is // not required if already kept-dirty by another writable translation. - perms := usermem.AccessType{ + perms := hostarch.AccessType{ Read: true, Execute: true, } @@ -954,7 +957,7 @@ func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { } // MapInternal implements memmap.File.MapInternal. -func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() defer d.handleMu.RUnlock() return d.hostFileMapper.MapInternal(fr, int(d.mmapFD), at.Write) diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go index c90071e4e..83e841a51 100644 --- a/pkg/sentry/fsimpl/gofer/save_restore.go +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -22,12 +22,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) type saveRestoreContextID int @@ -85,7 +85,7 @@ func (fs *filesystem) PrepareSave(ctx context.Context) error { func (fd *specialFileFD) savePipeData(ctx context.Context) error { fd.bufMu.Lock() defer fd.bufMu.Unlock() - var buf [usermem.PageSize]byte + var buf [hostarch.PageSize]byte for { n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0)) if n != 0 { diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 4ae9d6d5e..b94dfeb7f 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -47,6 +47,7 @@ go_library( "//pkg/context", "//pkg/fdnotifier", "//pkg/fspath", + "//pkg/hostarch", "//pkg/iovec", "//pkg/log", "//pkg/marshal/primitive", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index b9cce4181..a81f550b1 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" @@ -431,8 +432,8 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } oldSize := uint64(hostStat.Size) if s.Size < oldSize { - oldpgend, _ := usermem.PageRoundUp(oldSize) - newpgend, _ := usermem.PageRoundUp(s.Size) + oldpgend, _ := hostarch.PageRoundUp(oldSize) + newpgend, _ := hostarch.PageRoundUp(s.Size) if oldpgend != newpgend { i.CachedMappable.InvalidateRange(memmap.MappableRange{newpgend, oldpgend}) } @@ -459,6 +460,9 @@ func (i *inode) DecRef(ctx context.Context) { if err := unix.Close(i.hostFD); err != nil { log.Warningf("failed to close host fd %d: %v", i.hostFD, err) } + // We can't rely on fdnotifier when closing the fd, because the event may race + // with fdnotifier.RemoveFD. 
Instead, notify the queue explicitly. + i.queue.Notify(waiter.EventHUp | waiter.ReadableEvents | waiter.WritableEvents) }) } diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go index 5688bddc8..c502d8e99 100644 --- a/pkg/sentry/fsimpl/host/save_restore.go +++ b/pkg/sentry/fsimpl/host/save_restore.go @@ -21,9 +21,9 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostfd" - "gvisor.dev/gvisor/pkg/usermem" ) // beforeSave is invoked by stateify. @@ -38,7 +38,7 @@ func (i *inode) beforeSave() { // EBADF from the read. i.bufMu.Lock() defer i.bufMu.Unlock() - var buf [usermem.PageSize]byte + var buf [hostarch.PageSize]byte for { n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */) if n != 0 { @@ -68,3 +68,10 @@ func (i *inode) afterLoad() { } } } + +// afterLoad is invoked by stateify. +func (c *ConnectedEndpoint) afterLoad() { + if err := c.initFromOptions(); err != nil { + panic(fmt.Sprintf("initFromOptions failed: %v", err)) + } +} diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 60e237ac7..ca85f5601 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -39,7 +39,7 @@ import ( func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) { // Set up an external transport.Endpoint using the host fd. addr := fmt.Sprintf("hostfd:[%d]", hostFD) - e, err := NewConnectedEndpoint(ctx, hostFD, addr, true /* saveable */) + e, err := NewConnectedEndpoint(hostFD, addr) if err != nil { return nil, err.ToError() } @@ -86,7 +86,10 @@ type ConnectedEndpoint struct { // for restoring them. func (c *ConnectedEndpoint) init() *syserr.Error { c.InitRefs() + return c.initFromOptions() +} +func (c *ConnectedEndpoint) initFromOptions() *syserr.Error { family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN) if err != nil { return syserr.FromError(err) @@ -123,7 +126,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { // The caller is responsible for calling Init(). Additionaly, Release needs to // be called twice because ConnectedEndpoint is both a transport.Receiver and // transport.ConnectedEndpoint. -func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) { +func NewConnectedEndpoint(hostFD int, addr string) (*ConnectedEndpoint, *syserr.Error) { e := ConnectedEndpoint{ fd: hostFD, addr: addr, @@ -330,8 +333,16 @@ func (c *ConnectedEndpoint) CloseUnread() {} // SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize. func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { - // gVisor does not permit setting of SO_SNDBUF for host backed unix domain - // sockets. + // gVisor does not permit setting of SO_SNDBUF for host backed unix + // domain sockets. + return atomic.LoadInt64(&c.sndbuf) +} + +// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize. +func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) { + // gVisor does not permit setting of SO_RCVBUF for host backed unix + // domain sockets. Receive buffer does not have any effect for unix + // sockets and we claim to be the same as send buffer. 
return atomic.LoadInt64(&c.sndbuf) } diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 6dbc7e34d..b7d13cced 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -105,6 +105,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/hostarch", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 65054b0ea..84b1c3745 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -25,8 +25,10 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// DynamicBytesFile implements kernfs.Inode and represents a read-only -// file whose contents are backed by a vfs.DynamicBytesSource. +// DynamicBytesFile implements kernfs.Inode and represents a read-only file +// whose contents are backed by a vfs.DynamicBytesSource. If data additionally +// implements vfs.WritableDynamicBytesSource, the file also supports dispatching +// writes to the implementer, but note that this will not update the source data. // // Must be instantiated with NewDynamicBytesFile or initialized with Init // before first use. @@ -40,7 +42,9 @@ type DynamicBytesFile struct { InodeNotSymlink locks vfs.FileLocks - data vfs.DynamicBytesSource + // data can additionally implement vfs.WritableDynamicBytesSource to support + // writes. + data vfs.DynamicBytesSource } var _ Inode = (*DynamicBytesFile)(nil) diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 6b890a39c..3d0866ecf 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -20,12 +20,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // InodeNoopRefCount partially implements the Inode interface, specifically the @@ -206,7 +206,7 @@ func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) atomic.StoreUint32(&a.nlink, nlink) - atomic.StoreUint32(&a.blockSize, usermem.PageSize) + atomic.StoreUint32(&a.blockSize, hostarch.PageSize) now := ktime.NowFromContext(ctx).Nanoseconds() atomic.StoreInt64(&a.atime, now) atomic.StoreInt64(&a.mtime, now) diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 565d723f0..16486eeae 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -61,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -508,6 +509,15 @@ func (d *Dentry) Inode() Inode { return d.inode } +// FSLocalPath returns an absolute path to d, relative to the root of its +// filesystem. 
+func (d *Dentry) FSLocalPath() string { + var b fspath.Builder + _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b) + b.PrependByte('/') + return b.String() +} + // The Inode interface maps filesystem-level operations that operate on paths to // equivalent operations on specific filesystem nodes. // diff --git a/pkg/sentry/fsimpl/kernfs/mmap_util.go b/pkg/sentry/fsimpl/kernfs/mmap_util.go index bd6a134b4..d1539d904 100644 --- a/pkg/sentry/fsimpl/kernfs/mmap_util.go +++ b/pkg/sentry/fsimpl/kernfs/mmap_util.go @@ -16,11 +16,11 @@ package kernfs import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) // inodePlatformFile implements memmap.File. It exists solely because inode @@ -66,7 +66,7 @@ func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { } // MapInternal implements memmap.File.MapInternal. -func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } @@ -100,7 +100,7 @@ func (i *CachedMappable) Init(hostFD int) { } // AddMapping implements memmap.Mappable.AddMapping. -func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { i.mapsMu.Lock() mapped := i.mappings.AddMapping(ms, ar, offset, writable) for _, r := range mapped { @@ -111,7 +111,7 @@ func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { i.mapsMu.Lock() unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable) for _, r := range unmapped { @@ -121,19 +121,19 @@ func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpa } // CopyMapping implements memmap.Mappable.CopyMapping. -func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { return i.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. 
-func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { mr := optional return []memmap.Translation{ { Source: mr, File: &i.pf, Offset: mr.Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }, }, nil } diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index bf13bbbf4..5504476c8 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/hostarch", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 27b00cf6f..45aa5a494 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -21,11 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) func (d *dentry) isCopiedUp() bool { @@ -138,8 +138,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { // We may have memory mappings of the file on the lower layer. // Switch to mapping the file on the upper layer instead. mmapOpts = &memmap.MMapOpts{ - Perms: usermem.ReadWrite, - MaxPerms: usermem.ReadWrite, + Perms: hostarch.ReadWrite, + MaxPerms: hostarch.ReadWrite, } if err := newFD.ConfigureMMap(ctx, mmapOpts); err != nil { cleanupUndoCopyUp() diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index d791c06db..43bfd69a3 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -445,7 +446,7 @@ func (fd *regularFileFD) ensureMappable(ctx context.Context, opts *memmap.MMapOp } // AddMapping implements memmap.Mappable.AddMapping. -func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { d.mapsMu.Lock() defer d.mapsMu.Unlock() if err := d.wrappedMappable.AddMapping(ctx, ms, ar, offset, writable); err != nil { @@ -458,7 +459,7 @@ func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar user } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { d.mapsMu.Lock() defer d.mapsMu.Unlock() d.wrappedMappable.RemoveMapping(ctx, ms, ar, offset, writable) @@ -468,7 +469,7 @@ func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar u } // CopyMapping implements memmap.Mappable.CopyMapping. 
-func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { d.mapsMu.Lock() defer d.mapsMu.Unlock() if err := d.wrappedMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil { @@ -481,7 +482,7 @@ func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, } // Translate implements memmap.Mappable.Translate. -func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { d.dataMu.RLock() defer d.dataMu.RUnlock() return d.wrappedMappable.Translate(ctx, required, optional, at) diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD index 5950a2d59..278ee3c92 100644 --- a/pkg/sentry/fsimpl/pipefs/BUILD +++ b/pkg/sentry/fsimpl/pipefs/BUILD @@ -10,12 +10,12 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/hostarch", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/vfs", "//pkg/syserror", - "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 3f05e444e..08aedc2ad 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -22,13 +22,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // +stateify savable @@ -131,7 +131,7 @@ func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOpti ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds()) return linux.Statx{ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, - Blksize: usermem.PageSize, + Blksize: hostarch.PageSize, Nlink: 1, UID: uint32(i.uid), GID: uint32(i.gid), diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index d47a4fff9..2b628bd55 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -81,6 +81,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 254a8b062..ce8f55b1f 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -86,13 +86,13 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF procfs.MaxCachedDentries = maxCachedDentries procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) - var cgroups map[string]string + var fakeCgroupControllers map[string]string if opts.InternalData != nil { data := opts.InternalData.(*InternalData) - cgroups = data.Cgroups + fakeCgroupControllers = data.Cgroups } - 
inode := procfs.newTasksInode(ctx, k, pidns, cgroups) + inode := procfs.newTasksInode(ctx, k, pidns, fakeCgroupControllers) var dentry kernfs.Dentry dentry.InitRoot(&procfs.Filesystem, inode) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index fea138f93..d05cc1508 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -47,7 +47,7 @@ type taskInode struct { var _ kernfs.Inode = (*taskInode)(nil) -func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) { +func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, fakeCgroupControllers map[string]string) (kernfs.Inode, error) { if task.ExitState() == kernel.TaskExitDead { return nil, syserror.ESRCH } @@ -82,10 +82,12 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns "uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { - contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers) + contents["task"] = fs.newSubtasks(ctx, task, pidns, fakeCgroupControllers) } - if len(cgroupControllers) > 0 { - contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers)) + if len(fakeCgroupControllers) > 0 { + contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newFakeCgroupData(fakeCgroupControllers)) + } else { + contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskCgroupData{task: task}) } taskInode := &taskInode{task: task} @@ -226,11 +228,14 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData { return &ioData{ioUsage: t} } -// newCgroupData creates inode that shows cgroup information. -// From man 7 cgroups: "For each cgroup hierarchy of which the process is a -// member, there is one entry containing three colon-separated fields: -// hierarchy-ID:controller-list:cgroup-path" -func newCgroupData(controllers map[string]string) dynamicInode { +// newFakeCgroupData creates an inode that shows fake cgroup +// information passed in as mount options. From man 7 cgroups: "For +// each cgroup hierarchy of which the process is a member, there is +// one entry containing three colon-separated fields: +// hierarchy-ID:controller-list:cgroup-path" +// +// TODO(b/182488796): Remove once all users adopt cgroupfs. 
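The format newFakeCgroupData emits is the one described in the comment above: one hierarchy-ID:controller-list:cgroup-path line per controller, as documented in man 7 cgroups. A minimal stdlib-only sketch of producing lines in that shape, with a made-up controller map rather than real mount options:

    // Illustration only: emits "hierarchy-ID:controller-list:cgroup-path" lines.
    // The controller map and hierarchy IDs are hypothetical, not gVisor data.
    package main

    import (
    	"bytes"
    	"fmt"
    )

    func main() {
    	controllers := map[string]string{"cpu": "/docker/abc123", "memory": "/docker/abc123"}
    	var buf bytes.Buffer
    	id := 1
    	for name, path := range controllers {
    		fmt.Fprintf(&buf, "%d:%s:%s\n", id, name, path)
    		id++
    	}
    	fmt.Print(buf.String()) // map iteration order is not stable; fine for illustration
    }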
+func newFakeCgroupData(controllers map[string]string) dynamicInode { var buf bytes.Buffer // The hierarchy ids must be positive integers (for cgroup v1), but the diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index fdae163d1..b294dfd6a 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" @@ -122,8 +123,8 @@ func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error { buf.Grow((len(auxv) + 1) * 16) for _, e := range auxv { var tmp [16]byte - usermem.ByteOrder.PutUint64(tmp[:8], e.Key) - usermem.ByteOrder.PutUint64(tmp[8:], uint64(e.Value)) + hostarch.ByteOrder.PutUint64(tmp[:8], e.Key) + hostarch.ByteOrder.PutUint64(tmp[8:], uint64(e.Value)) buf.Write(tmp[:]) } var atNull [16]byte @@ -168,15 +169,15 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { defer m.DecUsers(ctx) // Figure out the bounds of the exec arg we are trying to read. - var ar usermem.AddrRange + var ar hostarch.AddrRange switch d.arg { case cmdlineDataArg: - ar = usermem.AddrRange{ + ar = hostarch.AddrRange{ Start: m.ArgvStart(), End: m.ArgvEnd(), } case environDataArg: - ar = usermem.AddrRange{ + ar = hostarch.AddrRange{ Start: m.EnvvStart(), End: m.EnvvEnd(), } @@ -192,7 +193,7 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { // until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading // cmdline and environment"). writer := &bufferWriter{buf: buf} - if n, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil { + if n, err := m.CopyInTo(ctx, hostarch.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil { // Nothing to copy or something went wrong. return err } @@ -209,7 +210,7 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { } // There is no NULL terminator in the string, return into envp. - arEnvv := usermem.AddrRange{ + arEnvv := hostarch.AddrRange{ Start: m.EnvvStart(), End: m.EnvvEnd(), } @@ -218,11 +219,11 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208 // we'll return one page total between argv and envp because of the // above page restrictions. - if buf.Len() >= usermem.PageSize { + if buf.Len() >= hostarch.PageSize { // Returned at least one page already, nothing else to add. return nil } - remaining := usermem.PageSize - buf.Len() + remaining := hostarch.PageSize - buf.Len() if int(arEnvv.Length()) > remaining { end, ok := arEnvv.Start.AddLength(uint64(remaining)) if !ok { @@ -230,7 +231,7 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { } arEnvv.End = end } - if _, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil { + if _, err := m.CopyInTo(ctx, hostarch.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil { return err } @@ -323,7 +324,7 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in // the system page size, and the write must be performed at the start of // the file ..." 
- user_namespaces(7) srclen := src.NumBytes() - if srclen >= usermem.PageSize || offset != 0 { + if srclen >= hostarch.PageSize || offset != 0 { return 0, syserror.EINVAL } b := make([]byte, srclen) @@ -481,7 +482,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64 defer m.DecUsers(ctx) // Buffer the read data because of MM locks buf := make([]byte, dst.NumBytes()) - n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + n, readErr := m.CopyIn(ctx, hostarch.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) if n > 0 { if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { return 0, syserror.EFAULT @@ -613,7 +614,7 @@ func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { rss = mm.ResidentSetSize() } }) - fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize) + fmt.Fprintf(buf, "%d %d ", vss, rss/hostarch.PageSize) // rsslim. fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Limits().Get(limits.Rss).Cur) @@ -655,7 +656,7 @@ func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { } }) - fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize) + fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize) return nil } @@ -774,7 +775,7 @@ func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset } // Limit input size so as not to impact performance if input size is large. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) @@ -1099,3 +1100,32 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err func (fd *namespaceFD) Release(ctx context.Context) { fd.inode.DecRef(ctx) } + +// taskCgroupData generates data for /proc/[pid]/cgroup. +// +// +stateify savable +type taskCgroupData struct { + dynamicBytesFileSetAttr + task *kernel.Task +} + +var _ dynamicInode = (*taskCgroupData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *taskCgroupData) Generate(ctx context.Context, buf *bytes.Buffer) error { + // When a task is exiting on Linux, a task's cgroup set is cleared and + // reset to the initial cgroup set, which is essentially the set of root + // cgroups. Because of this, the /proc/<pid>/cgroup file is always readable + // on Linux throughout a task's lifetime. + // + // The sentry removes tasks from cgroups during the exit process, but + // doesn't move them into an initial cgroup set, so partway through task + // exit this file shows a task is in no cgroups, which is incorrect. Instead, + // once a task has left its cgroups, we return an error.
+ if d.task.ExitState() >= kernel.TaskExitInitiated { + return syserror.ESRCH + } + + d.task.GenerateProcTaskCgroup(buf) + return nil +} diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index d4f6a5a9b..177cb828f 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -34,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/usermem" ) func (fs *filesystem) newTaskNetDir(ctx context.Context, task *kernel.Task) kernfs.Inode { @@ -295,7 +295,7 @@ func networkToHost16(n uint16) uint16 { // binary.BigEndian.Uint16() require a read of binary.BigEndian and an // interface method call, defeating inlining. buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)} - return usermem.ByteOrder.Uint16(buf[:]) + return hostarch.ByteOrder.Uint16(buf[:]) } func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { @@ -317,14 +317,14 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { // __be32 which is a typedef for an unsigned int, and is printed with // %X. This means that for a little-endian machine, Linux prints the // least-significant byte of the address first. To emulate this, we first - // invert the byte order for the address using usermem.ByteOrder.Uint32, + // invert the byte order for the address using hostarch.ByteOrder.Uint32, // which makes it have the equivalent encoding to a __be32 on a little // endian machine. Note that this operation is a no-op on a big endian // machine. Then similar to Linux, we format it with %X, which will print // the most-significant byte of the __be32 address first, which is now // actually the least-significant byte of the original address in // linux.SockAddrInet.Addr on little endian machines, due to the conversion. 
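To make the byte-order note above concrete: on a little-endian host (where hostarch.ByteOrder is little-endian), reinterpreting the network-order IPv4 bytes in host byte order and formatting with %08X reproduces the /proc/net encoding. A stdlib-only sketch, assuming a little-endian machine:

    // Illustration only; not the gVisor code path.
    package main

    import (
    	"encoding/binary"
    	"fmt"
    )

    func main() {
    	// 127.0.0.1 as stored in linux.SockAddrInet.Addr: network (big-endian) order.
    	addr := [4]byte{127, 0, 0, 1}
    	// Reading it back in host (little-endian) order and printing with %08X gives
    	// "0100007F", matching what Linux prints in /proc/net/tcp on such machines.
    	fmt.Printf("%08X\n", binary.LittleEndian.Uint32(addr[:]))
    }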
- addr := usermem.ByteOrder.Uint32(a.Addr[:]) + addr := hostarch.ByteOrder.Uint32(a.Addr[:]) fmt.Fprintf(w, "%08X:%04X ", addr, port) case linux.AF_INET6: @@ -334,10 +334,10 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { } port := networkToHost16(a.Port) - addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4]) - addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8]) - addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12]) - addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16]) + addr0 := hostarch.ByteOrder.Uint32(a.Addr[0:4]) + addr1 := hostarch.ByteOrder.Uint32(a.Addr[4:8]) + addr2 := hostarch.ByteOrder.Uint32(a.Addr[8:12]) + addr3 := hostarch.ByteOrder.Uint32(a.Addr[12:16]) fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port) } } @@ -739,10 +739,10 @@ func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error { ) if len(rt.GatewayAddr) == header.IPv4AddressSize { flags |= linux.RTF_GATEWAY - gw = usermem.ByteOrder.Uint32(rt.GatewayAddr) + gw = hostarch.ByteOrder.Uint32(rt.GatewayAddr) } if len(rt.DstAddr) == header.IPv4AddressSize { - prefix = usermem.ByteOrder.Uint32(rt.DstAddr) + prefix = hostarch.ByteOrder.Uint32(rt.DstAddr) } l := fmt.Sprintf( "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d", diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index fdc580610..7c7543f14 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -54,15 +54,15 @@ type tasksInode struct { // '/proc/self' and '/proc/thread-self' have custom directory offsets in // Linux. So handle them outside of OrderedChildren. - // cgroupControllers is a map of controller name to directory in the + // fakeCgroupControllers is a map of controller name to directory in the // cgroup hierarchy. These controllers are immutable and will be listed // in /proc/pid/cgroup if not nil. - cgroupControllers map[string]string + fakeCgroupControllers map[string]string } var _ kernfs.Inode = (*tasksInode)(nil) -func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { +func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), @@ -76,11 +76,16 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}), "version": fs.newInode(ctx, root, 0444, &versionData{}), } + // If fakeCgroupControllers are provided, don't create a cgroupfs backed + // /proc/cgroup as it will not match the fake controllers. 
+ if len(fakeCgroupControllers) == 0 { + contents["cgroups"] = fs.newInode(ctx, root, 0444, &cgroupsData{}) + } inode := &tasksInode{ - pidns: pidns, - fs: fs, - cgroupControllers: cgroupControllers, + pidns: pidns, + fs: fs, + fakeCgroupControllers: fakeCgroupControllers, } inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.InitRefs() @@ -118,7 +123,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err return nil, syserror.ENOENT } - return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers) + return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers) } // IterDirents implements kernfs.inodeDirectory.IterDirents. diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 01b7a6678..e1a8b4409 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -28,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // +stateify savable @@ -270,7 +270,7 @@ func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { anon := snapshot.Anonymous + snapshot.Tmpfs file := snapshot.PageCache + snapshot.Mapped // We don't actually have active/inactive LRUs, so just make up numbers. - activeFile := (file / 2) &^ (usermem.PageSize - 1) + activeFile := (file / 2) &^ (hostarch.PageSize - 1) inactiveFile := file - activeFile fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024) @@ -384,3 +384,19 @@ func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error k.VFS().GenerateProcFilesystems(buf) return nil } + +// cgroupsData backs /proc/cgroups. +// +// +stateify savable +type cgroupsData struct { + dynamicBytesFileSetAttr +} + +var _ dynamicInode = (*cgroupsData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + r := kernel.KernelFromContext(ctx).CgroupRegistry() + r.GenerateProcCgroups(buf) + return nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index fb274b78e..9b14dd6b9 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -214,7 +215,7 @@ func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset } // Limit the amount of memory allocated. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) @@ -262,7 +263,7 @@ func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, off } // Limit the amount of memory allocated. 
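The TakeFirst(hostarch.PageSize - 1) guard that recurs in these Write handlers caps how many bytes are copied from userspace before a small integer is parsed, so an oversized write cannot force a large allocation. A stdlib-only sketch of the same bounded-parse idea; the reader here stands in for the usermem IO sequence and is not the gVisor API:

    // Illustration only: bound the input before parsing, mirroring the one-page limit.
    package main

    import (
    	"fmt"
    	"io"
    	"strconv"
    	"strings"
    )

    const pageSize = 4096

    func parseInt32Bounded(r io.Reader) (int32, error) {
    	buf := make([]byte, pageSize-1) // never copy more than one page of input
    	n, _ := r.Read(buf)
    	v, err := strconv.ParseInt(strings.TrimSpace(string(buf[:n])), 10, 32)
    	return int32(v), err
    }

    func main() {
    	v, err := parseInt32Bounded(strings.NewReader("1\n"))
    	fmt.Println(v, err) // 1 <nil>
    }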
- src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) @@ -318,7 +319,7 @@ func (d *tcpMemData) Write(ctx context.Context, src usermem.IOSequence, offset i defer d.mu.Unlock() // Limit the amount of memory allocated. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) size, err := d.readSizeLocked() if err != nil { return 0, err @@ -406,7 +407,7 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs } // Limit input size so as not to impact performance if input size is large. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) @@ -463,7 +464,7 @@ func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset i // Limit input size so as not to impact performance if input size is // large. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) ports := make([]int32, 2) n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts) diff --git a/pkg/sentry/fsimpl/proc/yama.go b/pkg/sentry/fsimpl/proc/yama.go index aebfe8944..e039ec45e 100644 --- a/pkg/sentry/fsimpl/proc/yama.go +++ b/pkg/sentry/fsimpl/proc/yama.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -62,7 +63,7 @@ func (s *yamaPtraceScope) Write(ctx context.Context, src usermem.IOSequence, off } // Limit the amount of memory allocated. - src = src.TakeFirst(usermem.PageSize - 1) + src = src.TakeFirst(hostarch.PageSize - 1) var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1d9280dae..14eb10dcd 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -122,11 +122,11 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs } func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode { - // If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs - // should be mounted at debug/, but for our purposes, it is sufficient to - // keep it in sys. + // Set up /sys/kernel/debug/kcov. Technically, debugfs should be + // mounted at debug/, but for our purposes, it is sufficient to keep it + // in sys. 
var children map[string]kernfs.Inode - if coverage.KcovAvailable() { + if coverage.KcovSupported() { log.Debugf("Set up /sys/kernel/debug/kcov") children = map[string]kernfs.Inode{ "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{ diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 400a97996..b3f9d1010 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -15,6 +15,7 @@ go_library( "//pkg/context", "//pkg/cpuid", "//pkg/fspath", + "//pkg/hostarch", "//pkg/memutil", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/tmpfs", diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go index 1a8525b06..59e6f9c92 100644 --- a/pkg/sentry/fsimpl/testutil/testutil.go +++ b/pkg/sentry/fsimpl/testutil/testutil.go @@ -30,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // System represents the context for a single test. @@ -105,7 +107,7 @@ func (s *System) Destroy() { // ReadToEnd reads the contents of fd until EOF to a string. func (s *System) ReadToEnd(fd *vfs.FileDescription) (string, error) { - buf := make([]byte, usermem.PageSize) + buf := make([]byte, hostarch.PageSize) bufIOSeq := usermem.BytesIOSequence(buf) opts := vfs.ReadOptions{} diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD index fbb02a271..7ce7dc429 100644 --- a/pkg/sentry/fsimpl/timerfd/BUILD +++ b/pkg/sentry/fsimpl/timerfd/BUILD @@ -8,6 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/kernel/time", "//pkg/sentry/vfs", "//pkg/syserror", diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go index 64d33c3a8..cbb8b67c5 100644 --- a/pkg/sentry/fsimpl/timerfd/timerfd.go +++ b/pkg/sentry/fsimpl/timerfd/timerfd.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -72,7 +73,7 @@ func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequenc } if val := atomic.SwapUint64(&tfd.val, 0); val != 0 { var buf [sizeofUint64]byte - usermem.ByteOrder.PutUint64(buf[:], val) + hostarch.ByteOrder.PutUint64(buf[:], val) if _, err := dst.CopyOut(ctx, buf[:]); err != nil { // Linux does not undo consuming the number of // expirations even if writing to userspace fails. diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 09957c2b7..e21fddd7f 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -59,6 +59,7 @@ go_library( "//pkg/amutex", "//pkg/context", "//pkg/fspath", + "//pkg/hostarch", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index a6d161882..c45bddff6 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -224,7 +225,7 @@ func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) { } // AddMapping implements memmap.Mappable.AddMapping. 
-func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { rf.mapsMu.Lock() defer rf.mapsMu.Unlock() rf.dataMu.RLock() @@ -240,7 +241,7 @@ func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, a pagesBefore := rf.writableMappingPages // ar is guaranteed to be page aligned per memmap.Mappable. - rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize) + rf.writableMappingPages += uint64(ar.Length() / hostarch.PageSize) if rf.writableMappingPages < pagesBefore { panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages)) @@ -251,7 +252,7 @@ func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, a } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { rf.mapsMu.Lock() defer rf.mapsMu.Unlock() @@ -261,7 +262,7 @@ func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace pagesBefore := rf.writableMappingPages // ar is guaranteed to be page aligned per memmap.Mappable. - rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize) + rf.writableMappingPages -= uint64(ar.Length() / hostarch.PageSize) if rf.writableMappingPages > pagesBefore { panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages)) @@ -270,12 +271,12 @@ func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace } // CopyMapping implements memmap.Mappable.CopyMapping. -func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { return rf.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. -func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { rf.dataMu.Lock() defer rf.dataMu.Unlock() @@ -307,7 +308,7 @@ func (rf *regularFile) Translate(ctx context.Context, required, optional memmap. Source: segMR, File: rf.memFile, Offset: seg.FileRangeOf(segMR).Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }) translatedEnd = segMR.End } @@ -487,6 +488,7 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. 
func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { file := fd.inode().impl.(*regularFile) + opts.SentryOwnedContent = true return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts) } @@ -539,7 +541,7 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er switch { case seg.Ok(): // Get internal mappings. - ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) + ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Read) if err != nil { return done, err } @@ -608,7 +610,7 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, // // See Linux, mm/filemap.c:generic_perform_write() and // mm/shmem.c:shmem_write_begin(). - if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart { + if pgstart := uint64(hostarch.Addr(rw.file.size).RoundDown()); end > pgstart { end = pgstart } if end <= rw.off { @@ -619,8 +621,8 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, // Page-aligned mr for when we need to allocate memory. RoundUp can't // overflow since end is an int64. - pgstartaddr := usermem.Addr(rw.off).RoundDown() - pgendaddr, _ := usermem.Addr(end).RoundUp() + pgstartaddr := hostarch.Addr(rw.off).RoundDown() + pgendaddr, _ := hostarch.Addr(end).RoundUp() pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)} var ( @@ -633,7 +635,7 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, switch { case seg.Ok(): // Get internal mappings. - ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write) + ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Write) if err != nil { retErr = err goto exitLoop diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 8df81f589..9ae25ce9e 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -36,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -43,7 +44,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Name is the default filesystem name. @@ -252,8 +252,8 @@ func (d *dentry) releaseChildrenLocked(ctx context.Context) { // immutable var globalStatfs = linux.Statfs{ Type: linux.TMPFS_MAGIC, - BlockSize: usermem.PageSize, - FragmentSize: usermem.PageSize, + BlockSize: hostarch.PageSize, + FragmentSize: hostarch.PageSize, NameLength: linux.NAME_MAX, // tmpfs currently does not support configurable size limits. In Linux, @@ -263,9 +263,9 @@ var globalStatfs = linux.Statfs{ // chosen to ensure that BlockSize * Blocks does not overflow int64 (which // applications may also handle incorrectly). // TODO(b/29637826): allow configuring a tmpfs size and enforce it. - Blocks: math.MaxInt64 / usermem.PageSize, - BlocksFree: math.MaxInt64 / usermem.PageSize, - BlocksAvailable: math.MaxInt64 / usermem.PageSize, + Blocks: math.MaxInt64 / hostarch.PageSize, + BlocksFree: math.MaxInt64 / hostarch.PageSize, + BlocksAvailable: math.MaxInt64 / hostarch.PageSize, } // dentry implements vfs.DentryImpl. 
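The tmpfs write path above clamps the affected range to page boundaries with hostarch.Addr's RoundDown/RoundUp before allocating pages. The arithmetic reduces to masking against the page size; a self-contained sketch with local helpers (not the hostarch API itself):

    // Illustration only: page round-down/round-up arithmetic.
    package main

    import "fmt"

    const pageSize = 4096 // hostarch.PageSize on x86-64 and arm64

    // roundDown aligns addr to the start of its page.
    func roundDown(addr uint64) uint64 { return addr &^ (pageSize - 1) }

    // roundUp aligns addr to the next page boundary; ok is false if the addition wrapped.
    func roundUp(addr uint64) (uint64, bool) {
    	up := roundDown(addr + pageSize - 1)
    	return up, up >= addr
    }

    func main() {
    	pgstart := roundDown(5000)
    	pgend, ok := roundUp(13000)
    	fmt.Println(pgstart, pgend, ok) // 4096 16384 true
    }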
@@ -485,7 +485,7 @@ func (i *inode) statTo(stat *linux.Statx) { linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_ATIME | linux.STATX_CTIME | linux.STATX_MTIME - stat.Blksize = usermem.PageSize + stat.Blksize = hostarch.PageSize stat.Nlink = atomic.LoadUint32(&i.nlink) stat.UID = atomic.LoadUint32(&i.uid) stat.GID = atomic.LoadUint32(&i.gid) diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index e265be0ee..d473a922d 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -14,13 +14,16 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/hostarch", "//pkg/marshal/primitive", "//pkg/merkletree", "//pkg/refsvfs2", + "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/sync", diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 6cb1a23e0..ca8090bbf 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -200,7 +200,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -209,7 +209,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's hash. @@ -223,12 +223,14 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err } + defer parentMerkleFD.DecRef(ctx) + // dataSize is the size of raw data for the Merkle tree. For a file, // dataSize is the size of the whole file. For a directory, dataSize is // the size of all its children's hashes. @@ -241,7 +243,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. 
if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -251,7 +253,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := FileReadWriteSeeker{ @@ -264,7 +266,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi Start: parent.lowerVD, }, &vfs.StatOptions{}) if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -294,7 +296,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi }) parent.hashMu.RUnlock() if err != nil && err != io.EOF { - return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. @@ -331,19 +333,21 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { return err } + defer fd.DecRef(ctx) + merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ Name: merkleSizeXattr, Size: sizeOfStringInt32, }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return err @@ -351,7 +355,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry size, err := strconv.Atoi(merkleSize) if err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } if d.isDir() && len(d.childrenNames) == 0 { @@ -361,14 +365,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) } if err != nil { return err } childrenOffset, err := strconv.Atoi(childrenOffString) if err != nil { - return 
alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) } childrenSizeString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ @@ -377,23 +381,23 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) } if err != nil { return err } childrenSize, err := strconv.Atoi(childrenSizeString) if err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) } childrenNames := make([]byte, childrenSize) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(childrenOffset), vfs.ReadOptions{}); err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err)) } if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err)) } } @@ -438,7 +442,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry } if _, err := merkletree.Verify(params); err != nil && err != io.EOF { - return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) } d.mode = uint32(stat.Mode) d.uid = stat.UID @@ -471,7 +475,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // The file was previously accessed. If the // file does not exist now, it indicates an // unexpected modification to the file system. - return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) } if err != nil { return nil, err @@ -483,7 +487,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // does not exist now, it indicates an unexpected // modification to the file system. 
if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) } if err != nil { return nil, err @@ -553,8 +557,8 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } childVD, err := parent.getLowerAt(ctx, vfsObj, name) - if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) + if parent.verityEnabled() && err == syserror.ENOENT { + return nil, fs.alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) } if err != nil { return nil, err @@ -565,30 +569,31 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, defer childVD.DecRef(ctx) childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) - if err == syserror.ENOENT { - if !fs.allowRuntimeEnable { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) - } - childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(merklePrefix + name), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT, - Mode: 0644, - }) - if err != nil { - return nil, err - } - childMerkleFD.DecRef(ctx) - childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) - if err != nil { + if err != nil { + if err == syserror.ENOENT { + if parent.verityEnabled() { + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) + } + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(merklePrefix + name), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + if err != nil { + return nil, err + } + } else { return nil, err } } - if err != nil && err != syserror.ENOENT { - return nil, err - } // Clear the Merkle tree file if they are to be generated at runtime. // TODO(b/182315468): Optimize the Merkle tree generate process to @@ -632,8 +637,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childVD.IncRef() childMerkleVD.IncRef() - parent.IncRef() - child.parent = parent child.name = name child.mode = uint32(stat.Mode) @@ -657,6 +660,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } } + parent.IncRef() + child.parent = parent + return child, nil } @@ -855,7 +861,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // missing, it indicates an unexpected modification to the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err } @@ -878,7 +884,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // the file system. 
if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -903,7 +909,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }) if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -921,7 +927,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err != nil { if err == syserror.ENOENT { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } return nil, err } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 0d9b0ee2c..458c7fcb6 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -34,6 +34,8 @@ package verity import ( + "bytes" + "encoding/hex" "encoding/json" "fmt" "math" @@ -44,13 +46,16 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/merkletree" "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -93,14 +98,18 @@ const ( ) var ( - // action specifies the action towards detected violation. - action ViolationAction - // verityMu synchronizes concurrent operations that enable verity and perform // verification checks. verityMu sync.RWMutex ) +// Mount option names for verityfs. +const ( + moptLowerPath = "lower_path" + moptRootHash = "root_hash" + moptRootName = "root_name" +) + // HashAlgorithm is a type specifying the algorithm used to hash the file // content. type HashAlgorithm int @@ -167,6 +176,12 @@ type filesystem struct { // system. alg HashAlgorithm + // action specifies the action towards detected violation. + action ViolationAction + + // opts is the string mount options passed to opts.Data. + opts string + // renameMu synchronizes renaming with non-renaming operations in order // to ensure consistent lock ordering between dentry.dirMu in different // dentries. @@ -189,9 +204,6 @@ type filesystem struct { // // +stateify savable type InternalFilesystemOptions struct { - // RootMerkleFileName is the name of the verity root Merkle tree file. - RootMerkleFileName string - // LowerName is the name of the filesystem wrapped by verity fs. LowerName string @@ -199,9 +211,6 @@ type InternalFilesystemOptions struct { // system. Alg HashAlgorithm - // RootHash is the root hash of the overall verity file system. - RootHash []byte - // AllowRuntimeEnable specifies whether the verity file system allows // enabling verification for files (i.e. 
building Merkle trees) during // runtime. @@ -226,8 +235,8 @@ func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In ErrorOnViolation // mode, it returns EIO, otherwise it panic. -func alertIntegrityViolation(msg string) error { - if action == ErrorOnViolation { +func (fs *filesystem) alertIntegrityViolation(msg string) error { + if fs.action == ErrorOnViolation { return syserror.EIO } panic(msg) @@ -235,28 +244,99 @@ func alertIntegrityViolation(msg string) error { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + mopts := vfs.GenericParseMountOptions(opts.Data) + var rootHash []byte + if encodedRootHash, ok := mopts[moptRootHash]; ok { + delete(mopts, moptRootHash) + hash, err := hex.DecodeString(encodedRootHash) + if err != nil { + ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err) + return nil, nil, syserror.EINVAL + } + rootHash = hash + } + var lowerPathname string + if path, ok := mopts[moptLowerPath]; ok { + delete(mopts, moptLowerPath) + lowerPathname = path + } + rootName := "root" + if root, ok := mopts[moptRootName]; ok { + delete(mopts, moptRootName) + rootName = root + } + + // Check for unparsed options. + if len(mopts) != 0 { + ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts) + return nil, nil, syserror.EINVAL + } + + // Handle internal options. iopts, ok := opts.InternalData.(InternalFilesystemOptions) - if !ok { + if len(lowerPathname) == 0 && !ok { ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } - action = iopts.Action - - // Mount the lower file system. The lower file system is wrapped inside - // verity, and should not be exposed or connected. - mopts := &vfs.MountOptions{ - GetFilesystemOptions: iopts.LowerGetFSOptions, - InternalMount: true, + if len(lowerPathname) != 0 { + if ok { + ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path") + return nil, nil, syserror.EINVAL + } + iopts = InternalFilesystemOptions{ + AllowRuntimeEnable: len(rootHash) == 0, + Action: ErrorOnViolation, + } } - mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts) - if err != nil { - return nil, nil, err + + var lowerMount *vfs.Mount + var mountedLowerVD vfs.VirtualDentry + // Use an existing mount if lowerPath is provided. 
+ if len(lowerPathname) != 0 { + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + lowerPath := fspath.Parse(lowerPathname) + if !lowerPath.Absolute { + ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname) + return nil, nil, syserror.EINVAL + } + var err error + mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: lowerPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("verity.FilesystemType.GetFilesystem: failed to resolve lower_path %q: %v", lowerPathname, err) + return nil, nil, err + } + lowerMount = mountedLowerVD.Mount() + defer mountedLowerVD.DecRef(ctx) + } else { + // Mount the lower file system. The lower file system is wrapped inside + // verity, and should not be exposed or connected. + mountOpts := &vfs.MountOptions{ + GetFilesystemOptions: iopts.LowerGetFSOptions, + InternalMount: true, + } + mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mountOpts) + if err != nil { + return nil, nil, err + } + lowerMount = mnt } fs := &filesystem{ creds: creds.Fork(), alg: iopts.Alg, - lowerMount: mnt, + lowerMount: lowerMount, + action: iopts.Action, + opts: opts.Data, allowRuntimeEnable: iopts.AllowRuntimeEnable, } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -264,11 +344,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Construct the root dentry. d := fs.newDentry() d.refs = 1 - lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root()) + lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root()) lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + rootName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -309,7 +389,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // the root Merkle file, or it's never generated. fs.vfsfs.DecRef(ctx) d.DecRef(ctx) - return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") + return nil, nil, fs.alertIntegrityViolation("Failed to find root Merkle file") } // Clear the Merkle tree file if they are to be generated at runtime. 
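
The hunks above let verityfs take its configuration from the plain mount data string (lower_path, root_hash, root_name) instead of only from InternalFilesystemOptions, with root_hash supplied as hex. A rough standalone sketch of that key=value format, using only the standard library; the helper name and the sample values here are made up for illustration and are not part of gVisor:

    package main

    import (
        "encoding/hex"
        "fmt"
        "strings"
    )

    // parseVerityOpts splits a comma-separated "key=value" mount data string
    // and hex-decodes the root hash, mirroring the handling added in
    // GetFilesystem above.
    func parseVerityOpts(data string) (lowerPath, rootName string, rootHash []byte, err error) {
        rootName = "root" // default used when root_name is absent
        for _, kv := range strings.Split(data, ",") {
            if kv == "" {
                continue
            }
            parts := strings.SplitN(kv, "=", 2)
            if len(parts) != 2 {
                return "", "", nil, fmt.Errorf("malformed option %q", kv)
            }
            key, val := parts[0], parts[1]
            switch key {
            case "lower_path":
                lowerPath = val
            case "root_name":
                rootName = val
            case "root_hash":
                rootHash, err = hex.DecodeString(val)
                if err != nil {
                    return "", "", nil, fmt.Errorf("bad root_hash: %v", err)
                }
            default:
                return "", "", nil, fmt.Errorf("unknown option %q", key)
            }
        }
        return lowerPath, rootName, rootHash, nil
    }

    func main() {
        // Sample values only; a real root hash is produced when verity is enabled.
        lp, rn, rh, err := parseVerityOpts("root_name=rootfs,lower_path=/tmp/lower,root_hash=deadbeef")
        fmt.Println(lp, rn, rh, err)
    }

When lower_path is given, the filesystem resolves an existing mount instead of creating a disconnected one, and AllowRuntimeEnable is inferred from whether a root hash was supplied.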
@@ -348,9 +428,15 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID - d.hash = make([]byte, len(iopts.RootHash)) d.childrenNames = make(map[string]struct{}) + d.hashMu.Lock() + d.hash = make([]byte, len(rootHash)) + copy(d.hash, rootHash) + d.hashMu.Unlock() + + fs.rootDentry = d + if !d.isDir() { ctx.Warningf("verity root must be a directory") return nil, nil, syserror.EINVAL @@ -366,7 +452,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Size: sizeOfStringInt32, }) if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) } if err != nil { return nil, nil, err @@ -374,7 +460,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt off, err := strconv.Atoi(offString) if err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) } sizeString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{ @@ -385,14 +471,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Size: sizeOfStringInt32, }) if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) } if err != nil { return nil, nil, err } size, err := strconv.Atoi(sizeString) if err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) } lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ @@ -402,19 +488,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) } if err != nil { return nil, nil, err } + defer lowerMerkleFD.DecRef(ctx) + childrenNames := make([]byte, size) if _, err := lowerMerkleFD.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(off), vfs.ReadOptions{}); err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err)) } if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) } if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { @@ -422,13 +510,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } - d.hashMu.Lock() - copy(d.hash, iopts.RootHash) - d.hashMu.Unlock() 
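
As the hunk above shows, the root dentry's expected children are stored as a JSON-serialized set inside the root Merkle file, at the offset and size recorded in xattrs. A small stdlib-only illustration of how a map[string]struct{} survives that round trip; the file names here are placeholders:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        // The set of expected child names, as kept in dentry.childrenNames.
        children := map[string]struct{}{"bin": {}, "etc": {}, "data.txt": {}}

        // Serialize: struct{} values marshal to empty JSON objects, so the
        // payload looks like {"bin":{},"data.txt":{},"etc":{}}.
        buf, err := json.Marshal(children)
        if err != nil {
            panic(err)
        }
        fmt.Printf("stored children map: %s\n", buf)

        // Deserialize back into a set, as GetFilesystem does after PRead-ing
        // the recorded byte range out of the root Merkle file.
        var restored map[string]struct{}
        if err := json.Unmarshal(buf, &restored); err != nil {
            panic(err)
        }
        _, ok := restored["data.txt"]
        fmt.Println("data.txt expected:", ok)
    }
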
d.vfsd.Init(d) - fs.rootDentry = d - return &fs.vfsfs, &d.vfsd, nil } @@ -439,7 +522,7 @@ func (fs *filesystem) Release(ctx context.Context) { // MountOptions implements vfs.FilesystemImpl.MountOptions. func (fs *filesystem) MountOptions() string { - return "" + return fs.opts } // dentry implements vfs.DentryImpl. @@ -720,6 +803,10 @@ type fileDescription struct { // underlying file system. lowerFD *vfs.FileDescription + // lowerMappable is the memmap.Mappable corresponding to this file in the + // underlying file system. + lowerMappable memmap.Mappable + // merkleReader is the read-only FileDescription corresponding to the // Merkle tree file in the underlying file system. merkleReader *vfs.FileDescription @@ -792,7 +879,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa // Verify that the child is expected. if dirent.Name != "." && dirent.Name != ".." { if _, ok := fd.d.childrenNames[dirent.Name]; !ok { - return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name)) + return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name)) } } } @@ -806,7 +893,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa // The result should contain all children plus "." and "..". if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 { - return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) + return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) } for fd.off < int64(len(ds)) { @@ -978,7 +1065,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { // or directory other than the root, the parent Merkle tree file should // have also been initialized. if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { - return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") + return 0, fd.d.fs.alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } hash, dataSize, err := fd.generateMerkleLocked(ctx) @@ -1033,7 +1120,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { } // measureVerity returns the hash of fd, saved in verityDigest. -func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest usermem.Addr) (uintptr, error) { +func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest hostarch.Addr) (uintptr, error) { t := kernel.TaskFromContext(ctx) if t == nil { return 0, syserror.EINVAL @@ -1051,7 +1138,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm if fd.d.fs.allowRuntimeEnable { return 0, syserror.ENODATA } - return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found") + return 0, fd.d.fs.alertIntegrityViolation("Ioctl measureVerity: no hash found") } // The first part of VerityDigest is the metadata. @@ -1072,11 +1159,11 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm } // Now copy the root hash bytes to the memory after metadata. 
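
measureVerity above copies the digest to userspace as a short metadata header followed by the raw hash bytes at verityDigest + SizeOfDigestMetadata. A sketch only, assuming the usual Linux fsverity_digest layout of a 16-bit algorithm field and a 16-bit size field ahead of the digest; the algorithm ID and zeroed hash below are illustrative, not taken from gVisor:

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    func main() {
        const hashAlgSHA256 = 1      // illustrative algorithm ID
        hash := make([]byte, 32)     // pretend SHA-256 root hash, all zeros here

        // 4 bytes of metadata (digest_algorithm, digest_size), then the digest
        // itself, matching the "metadata then hash bytes" copy-out above.
        // Native byte order is assumed little-endian here.
        buf := make([]byte, 4+len(hash))
        binary.LittleEndian.PutUint16(buf[0:2], hashAlgSHA256)
        binary.LittleEndian.PutUint16(buf[2:4], uint16(len(hash)))
        copy(buf[4:], hash)

        fmt.Printf("digest buffer: alg=%d size=%d total=%d bytes\n",
            binary.LittleEndian.Uint16(buf[0:2]),
            binary.LittleEndian.Uint16(buf[2:4]),
            len(buf))
    }
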
- _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.hash) + _, err := t.CopyOutBytes(hostarch.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.hash) return 0, err } -func (fd *fileDescription) verityFlags(ctx context.Context, flags usermem.Addr) (uintptr, error) { +func (fd *fileDescription) verityFlags(ctx context.Context, flags hostarch.Addr) (uintptr, error) { f := int32(0) fd.d.hashMu.RLock() @@ -1141,7 +1228,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { return 0, err @@ -1151,7 +1238,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } dataReader := FileReadWriteSeeker{ @@ -1184,7 +1271,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of }) fd.d.hashMu.RUnlock() if err != nil { - return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } return n, err } @@ -1199,6 +1286,24 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op return 0, syserror.EROFS } +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + if err := fd.lowerFD.ConfigureMMap(ctx, opts); err != nil { + return err + } + fd.lowerMappable = opts.Mappable + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef(ctx) + opts.MappingIdentity = nil + } + + // Check if mmap is allowed on the lower filesystem. + if !opts.SentryOwnedContent { + return syserror.ENODEV + } + return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts) +} + // LockBSD implements vfs.FileDescriptionImpl.LockBSD. func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return fd.lowerFD.LockBSD(ctx, ownerPID, t, block) @@ -1224,6 +1329,115 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t return fd.lowerFD.TestPOSIX(ctx, uid, t, r) } +// Translate implements memmap.Mappable.Translate. +func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { + ts, err := fd.lowerMappable.Translate(ctx, required, optional, at) + if err != nil { + return ts, err + } + + // dataSize is the size of the whole file. + dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the xattr does not exist, it + // indicates unexpected modifications to the file system. 
+ if err == syserror.ENODATA { + return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + } + if err != nil { + return ts, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + size, err := strconv.Atoi(dataSize) + if err != nil { + return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + } + + merkleReader := FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + + for _, t := range ts { + // Content integrity relies on sentry owning the backing data. MapInternal is guaranteed + // to fetch sentry owned memory because we disallow verity mmaps otherwise. + ims, err := t.File.MapInternal(memmap.FileRange{t.Offset, t.Offset + t.Source.Length()}, hostarch.Read) + if err != nil { + return nil, err + } + dataReader := mmapReadSeeker{ims, t.Source.Start} + var buf bytes.Buffer + _, err = merkletree.Verify(&merkletree.VerifyParams{ + Out: &buf, + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), + ReadOffset: int64(t.Source.Start), + ReadSize: int64(t.Source.Length()), + Expected: fd.d.hash, + DataAndTreeInSameFile: false, + }) + if err != nil { + return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + } + } + return ts, err +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (fd *fileDescription) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { + return fd.lowerMappable.AddMapping(ctx, ms, ar, offset, writable) +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (fd *fileDescription) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { + fd.lowerMappable.RemoveMapping(ctx, ms, ar, offset, writable) +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (fd *fileDescription) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { + return fd.lowerMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable) +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (fd *fileDescription) InvalidateUnsavable(context.Context) error { + return nil +} + +// mmapReadSeeker is a helper struct used by fileDescription.Translate to pass +// a safemem.BlockSeq pointing to the mapped region as io.ReaderAt. +type mmapReadSeeker struct { + safemem.BlockSeq + Offset uint64 +} + +// ReadAt implements io.ReaderAt.ReadAt. off is the offset into the mapped file. +func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) { + bs := r.BlockSeq + // Adjust the offset into the mapped file to get the offset into the internally + // mapped region. + readOffset := off - int64(r.Offset) + if readOffset < 0 { + return 0, syserror.EINVAL + } + bs.DropFirst64(uint64(readOffset)) + view := bs.TakeFirst64(uint64(len(p))) + dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p)) + n, err := safemem.CopySeq(dst, view) + return int(n), err +} + // FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as // io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. 
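
mmapReadSeeker above adapts an internal mapping that begins at some file offset into an io.ReaderAt-style reader keyed by absolute file offsets. A minimal stand-in using a plain byte slice instead of a safemem.BlockSeq shows the same offset adjustment; the type and values are hypothetical and exist only to illustrate the arithmetic:

    package main

    import (
        "errors"
        "fmt"
    )

    // windowReader exposes a slice holding file contents that start at
    // windowStart, addressed by absolute file offsets, the same adjustment
    // mmapReadSeeker.ReadAt performs on its BlockSeq.
    type windowReader struct {
        data        []byte // mapped bytes
        windowStart int64  // file offset of data[0]
    }

    func (r *windowReader) ReadAt(p []byte, off int64) (int, error) {
        rel := off - r.windowStart // offset into the mapped window
        if rel < 0 || rel > int64(len(r.data)) {
            return 0, errors.New("offset outside mapped window")
        }
        n := copy(p, r.data[rel:])
        return n, nil
    }

    func main() {
        r := &windowReader{data: []byte("hello, verity"), windowStart: 4096}
        buf := make([]byte, 5)
        n, err := r.ReadAt(buf, 4103) // file offset 4103 -> window offset 7
        fmt.Println(n, err, string(buf[:n]))
    }
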
type FileReadWriteSeeker struct { diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 57bd65202..5c78a0019 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -89,10 +89,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, AllowUserMount: true, }) + data := "root_name=" + rootMerkleFilename mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: data, InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, LowerName: "tmpfs", Alg: hashAlg, AllowRuntimeEnable: true, diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index 300b7ccce..66fa1ad40 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -13,8 +13,8 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/fd", + "//pkg/hostarch", "//pkg/log", - "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go index c47b96b54..285ea9050 100644 --- a/pkg/sentry/hostmm/hostmm.go +++ b/pkg/sentry/hostmm/hostmm.go @@ -23,8 +23,8 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/usermem" ) // NotifyCurrentMemcgPressureCallback requests that f is called whenever the @@ -88,7 +88,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) if n != sizeofUint64 { panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64)) } - val := usermem.ByteOrder.Uint64(buf[:]) + val := hostarch.ByteOrder.Uint64(buf[:]) if val >= stopVal { // Assume this was due to the notifier's "destructor" (the // function returned by NotifyCurrentMemcgPressureCallback @@ -103,7 +103,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) return func() { rw := fd.NewReadWriter(eventFD.FD()) var buf [sizeofUint64]byte - usermem.ByteOrder.PutUint64(buf[:], stopVal) + hostarch.ByteOrder.PutUint64(buf[:], stopVal) for { n, err := rw.Write(buf[:]) if err != nil { diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index c53e3e720..a1ec6daab 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -141,6 +141,7 @@ go_library( srcs = [ "abstract_socket_namespace.go", "aio.go", + "cgroup.go", "context.go", "fd_table.go", "fd_table_refs.go", @@ -178,6 +179,7 @@ go_library( "task.go", "task_acct.go", "task_block.go", + "task_cgroup.go", "task_clone.go", "task_context.go", "task_exec.go", @@ -226,6 +228,7 @@ go_library( "//pkg/eventchannel", "//pkg/fspath", "//pkg/goid", + "//pkg/hostarch", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", @@ -240,6 +243,7 @@ go_library( "//pkg/sentry/fs/lock", "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", + "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/pipefs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/fsimpl/timerfd", @@ -294,6 +298,7 @@ go_test( deps = [ "//pkg/abi", "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/arch", "//pkg/sentry/contexttest", "//pkg/sentry/fs", @@ -305,6 +310,5 @@ go_test( "//pkg/sentry/usage", "//pkg/sync", "//pkg/syserror", - "//pkg/usermem", ], ) diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go new file mode 100644 index 000000000..1f1c63f37 --- /dev/null +++ b/pkg/sentry/kernel/cgroup.go @@ 
-0,0 +1,281 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "bytes" + "fmt" + "sort" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +// InvalidCgroupHierarchyID indicates an uninitialized hierarchy ID. +const InvalidCgroupHierarchyID uint32 = 0 + +// CgroupControllerType is the name of a cgroup controller. +type CgroupControllerType string + +// CgroupController is the common interface to cgroup controllers available to +// the entire sentry. The controllers themselves are defined by cgroupfs. +// +// Callers of this interface are often unable access synchronization needed to +// ensure returned values remain valid. Some of values returned from this +// interface are thus snapshots in time, and may become stale. This is ok for +// many callers like procfs. +type CgroupController interface { + // Returns the type of this cgroup controller (ex "memory", "cpu"). Returned + // value is valid for the lifetime of the controller. + Type() CgroupControllerType + + // Hierarchy returns the ID of the hierarchy this cgroup controller is + // attached to. Returned value is valid for the lifetime of the controller. + HierarchyID() uint32 + + // Filesystem returns the filesystem this controller is attached to. + // Returned value is valid for the lifetime of the controller. + Filesystem() *vfs.Filesystem + + // RootCgroup returns the root cgroup for this controller. Returned value is + // valid for the lifetime of the controller. + RootCgroup() Cgroup + + // NumCgroups returns the number of cgroups managed by this controller. + // Returned value is a snapshot in time. + NumCgroups() uint64 + + // Enabled returns whether this controller is enabled. Returned value is a + // snapshot in time. + Enabled() bool +} + +// Cgroup represents a named pointer to a cgroup in cgroupfs. When a task enters +// a cgroup, it holds a reference on the underlying dentry pointing to the +// cgroup. +// +// +stateify savable +type Cgroup struct { + *kernfs.Dentry + CgroupImpl +} + +func (c *Cgroup) decRef() { + c.Dentry.DecRef(context.Background()) +} + +// Path returns the absolute path of c, relative to its hierarchy root. +func (c *Cgroup) Path() string { + return c.FSLocalPath() +} + +// HierarchyID returns the id of the hierarchy that contains this cgroup. +func (c *Cgroup) HierarchyID() uint32 { + // Note: a cgroup is guaranteed to have at least one controller. + return c.Controllers()[0].HierarchyID() +} + +// CgroupImpl is the common interface to cgroups. +type CgroupImpl interface { + Controllers() []CgroupController + Enter(t *Task) + Leave(t *Task) +} + +// hierarchy represents a cgroupfs filesystem instance, with a unique set of +// controllers attached to it. Multiple cgroupfs mounts may reference the same +// hierarchy. 
+// +// +stateify savable +type hierarchy struct { + id uint32 + // These are a subset of the controllers in CgroupRegistry.controllers, + // grouped here by hierarchy for conveninent lookup. + controllers map[CgroupControllerType]CgroupController + // fs is not owned by hierarchy. The FS is responsible for unregistering the + // hierarchy on destruction, which removes this association. + fs *vfs.Filesystem +} + +func (h *hierarchy) match(ctypes []CgroupControllerType) bool { + if len(ctypes) != len(h.controllers) { + return false + } + for _, ty := range ctypes { + if _, ok := h.controllers[ty]; !ok { + return false + } + } + return true +} + +// CgroupRegistry tracks the active set of cgroup controllers on the system. +// +// +stateify savable +type CgroupRegistry struct { + // lastHierarchyID is the id of the last allocated cgroup hierarchy. Valid + // ids are from 1 to math.MaxUint32. Must be accessed through atomic ops. + // + lastHierarchyID uint32 + + mu sync.Mutex `state:"nosave"` + + // controllers is the set of currently known cgroup controllers on the + // system. Protected by mu. + // + // +checklocks:mu + controllers map[CgroupControllerType]CgroupController + + // hierarchies is the active set of cgroup hierarchies. Protected by mu. + // + // +checklocks:mu + hierarchies map[uint32]hierarchy +} + +func newCgroupRegistry() *CgroupRegistry { + return &CgroupRegistry{ + controllers: make(map[CgroupControllerType]CgroupController), + hierarchies: make(map[uint32]hierarchy), + } +} + +// nextHierarchyID returns a newly allocated, unique hierarchy ID. +func (r *CgroupRegistry) nextHierarchyID() (uint32, error) { + if hid := atomic.AddUint32(&r.lastHierarchyID, 1); hid != 0 { + return hid, nil + } + return InvalidCgroupHierarchyID, fmt.Errorf("cgroup hierarchy ID overflow") +} + +// FindHierarchy returns a cgroup filesystem containing exactly the set of +// controllers named in names. If no such FS is found, FindHierarchy return +// nil. FindHierarchy takes a reference on the returned FS, which is transferred +// to the caller. +func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Filesystem { + r.mu.Lock() + defer r.mu.Unlock() + + for _, h := range r.hierarchies { + if h.match(ctypes) { + h.fs.IncRef() + return h.fs + } + } + + return nil +} + +// Register registers the provided set of controllers with the registry as a new +// hierarchy. If any controller is already registered, the function returns an +// error without modifying the registry. The hierarchy can be later referenced +// by the returned id. +func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if len(cs) == 0 { + return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers") + } + + for _, c := range cs { + if _, ok := r.controllers[c.Type()]; ok { + return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy") + } + } + + hid, err := r.nextHierarchyID() + if err != nil { + return hid, err + } + + h := hierarchy{ + id: hid, + controllers: make(map[CgroupControllerType]CgroupController), + fs: cs[0].Filesystem(), + } + for _, c := range cs { + n := c.Type() + r.controllers[n] = c + h.controllers[n] = c + } + r.hierarchies[hid] = h + return hid, nil +} + +// Unregister removes a previously registered hierarchy from the registry. If +// the controller was not previously registered, Unregister is a no-op. 
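
FindHierarchy and Register above enforce two invariants: a controller type may be attached to at most one hierarchy, and a mount only reuses an existing hierarchy when the requested controller set matches it exactly. A standalone sketch of that exact-set match over plain strings (the controller names are placeholders):

    package main

    import "fmt"

    // matches reports whether the requested controller types are exactly the
    // set already attached to a hierarchy, mirroring hierarchy.match above.
    func matches(requested []string, attached map[string]bool) bool {
        if len(requested) != len(attached) {
            return false
        }
        for _, ty := range requested {
            if !attached[ty] {
                return false
            }
        }
        return true
    }

    func main() {
        h := map[string]bool{"cpu": true, "memory": true}
        fmt.Println(matches([]string{"memory", "cpu"}, h)) // true: same set, order irrelevant
        fmt.Println(matches([]string{"cpu"}, h))           // false: subset only
        fmt.Println(matches([]string{"cpu", "pids"}, h))   // false: different set
    }

A request whose set matches no existing hierarchy falls through to Register, which refuses any controller type already claimed elsewhere.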
+func (r *CgroupRegistry) Unregister(hid uint32) { + r.mu.Lock() + defer r.mu.Unlock() + + if h, ok := r.hierarchies[hid]; ok { + for name, _ := range h.controllers { + delete(r.controllers, name) + } + delete(r.hierarchies, hid) + } +} + +// computeInitialGroups takes a reference on each of the returned cgroups. The +// caller takes ownership of this returned reference. +func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[Cgroup]struct{} { + r.mu.Lock() + defer r.mu.Unlock() + + ctlSet := make(map[CgroupControllerType]CgroupController) + cgset := make(map[Cgroup]struct{}) + + // Remember controllers from the inherited cgroups set... + for cg, _ := range inherit { + cg.IncRef() // Ref transferred to caller. + for _, ctl := range cg.Controllers() { + ctlSet[ctl.Type()] = ctl + cgset[cg] = struct{}{} + } + } + + // ... and add the root cgroups of all the missing controllers. + for name, ctl := range r.controllers { + if _, ok := ctlSet[name]; !ok { + cg := ctl.RootCgroup() + cg.IncRef() // Ref transferred to caller. + cgset[cg] = struct{}{} + } + } + return cgset +} + +// GenerateProcCgroups writes the contents of /proc/cgroups to buf. +func (r *CgroupRegistry) GenerateProcCgroups(buf *bytes.Buffer) { + r.mu.Lock() + entries := make([]string, 0, len(r.controllers)) + for _, c := range r.controllers { + en := 0 + if c.Enabled() { + en = 1 + } + entries = append(entries, fmt.Sprintf("%s\t%d\t%d\t%d\n", c.Type(), c.HierarchyID(), c.NumCgroups(), en)) + } + r.mu.Unlock() + + sort.Strings(entries) + fmt.Fprint(buf, "#subsys_name\thierarchy\tnum_cgroups\tenabled\n") + for _, e := range entries { + fmt.Fprint(buf, e) + } +} diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index 7ecbd29ab..564c3d42e 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fdnotifier", + "//pkg/hostarch", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 2aca02fd5..4466fbc9d 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -186,7 +187,7 @@ func (e *EventOperations) read(ctx context.Context, dst usermem.IOSequence) erro e.wq.Notify(waiter.WritableEvents) var buf [8]byte - usermem.ByteOrder.PutUint64(buf[:], val) + hostarch.ByteOrder.PutUint64(buf[:], val) _, err := dst.CopyOut(ctx, buf[:]) return err } @@ -194,7 +195,7 @@ func (e *EventOperations) read(ctx context.Context, dst usermem.IOSequence) erro // Must be called with e.mu locked. 
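
The eventfd changes above replace usermem.ByteOrder with hostarch.ByteOrder when encoding the 8-byte counter; both name the host's native byte order. A tiny stdlib illustration, assuming a little-endian host such as amd64 or arm64:

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    func main() {
        // An eventfd read or write always moves exactly 8 bytes holding a
        // host-endian uint64 counter value.
        var buf [8]byte
        binary.LittleEndian.PutUint64(buf[:], 42)
        fmt.Printf("wire bytes: % x\n", buf)
        fmt.Println("decoded:", binary.LittleEndian.Uint64(buf[:]))
    }
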
func (e *EventOperations) hostWrite(val uint64) error { var buf [8]byte - usermem.ByteOrder.PutUint64(buf[:], val) + hostarch.ByteOrder.PutUint64(buf[:], val) _, err := unix.Write(e.hostfd, buf[:]) if err == unix.EWOULDBLOCK { return syserror.ErrWouldBlock @@ -207,7 +208,7 @@ func (e *EventOperations) write(ctx context.Context, src usermem.IOSequence) err if _, err := src.CopyIn(ctx, buf[:]); err != nil { return err } - val := usermem.ByteOrder.Uint64(buf[:]) + val := hostarch.ByteOrder.Uint64(buf[:]) return e.Signal(val) } diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 041e3d4ca..a75686cf3 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -37,6 +37,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/sentry/memmap", "//pkg/sync", @@ -52,8 +53,8 @@ go_test( library = ":futex", deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/sync", - "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index e4dcc4d40..0427cf3f4 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -20,10 +20,10 @@ package futex import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // KeyKind indicates the type of a Key. @@ -83,8 +83,8 @@ func (k *Key) clone() Key { } // Preconditions: k.Kind == KindPrivate or KindSharedPrivate. -func (k *Key) addr() usermem.Addr { - return usermem.Addr(k.Offset) +func (k *Key) addr() hostarch.Addr { + return hostarch.Addr(k.Offset) } // matches returns true if a wakeup on k2 should wake a waiter waiting on k. @@ -97,14 +97,14 @@ func (k *Key) matches(k2 *Key) bool { type Target interface { context.Context - // SwapUint32 gives access to usermem.IO.SwapUint32. - SwapUint32(addr usermem.Addr, new uint32) (uint32, error) + // SwapUint32 gives access to hostarch.IO.SwapUint32. + SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) - // CompareAndSwap gives access to usermem.IO.CompareAndSwapUint32. - CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) + // CompareAndSwap gives access to hostarch.IO.CompareAndSwapUint32. + CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) - // LoadUint32 gives access to usermem.IO.LoadUint32. - LoadUint32(addr usermem.Addr) (uint32, error) + // LoadUint32 gives access to hostarch.IO.LoadUint32. + LoadUint32(addr hostarch.Addr) (uint32, error) // GetSharedKey returns a Key with kind KindSharedPrivate or // KindSharedMappable corresponding to the memory mapped at address addr. @@ -112,11 +112,11 @@ type Target interface { // If GetSharedKey returns a Key with a non-nil MappingIdentity, a // reference is held on the MappingIdentity, which must be dropped by the // caller when the Key is no longer in use. - GetSharedKey(addr usermem.Addr) (Key, error) + GetSharedKey(addr hostarch.Addr) (Key, error) } // check performs a basic equality check on the given address. -func check(t Target, addr usermem.Addr, val uint32) error { +func check(t Target, addr hostarch.Addr, val uint32) error { cur, err := t.LoadUint32(addr) if err != nil { return err @@ -128,7 +128,7 @@ func check(t Target, addr usermem.Addr, val uint32) error { } // atomicOp performs a complex operation on the given address. 
-func atomicOp(t Target, addr usermem.Addr, opIn uint32) (bool, error) { +func atomicOp(t Target, addr hostarch.Addr, opIn uint32) (bool, error) { opType := (opIn >> 28) & 0xf cmp := (opIn >> 24) & 0xf opArg := (opIn >> 12) & 0xfff @@ -328,7 +328,7 @@ const ( ) // getKey returns a Key representing address addr in c. -func getKey(t Target, addr usermem.Addr, private bool) (Key, error) { +func getKey(t Target, addr hostarch.Addr, private bool) (Key, error) { // Ensure the address is aligned. // It must be a DWORD boundary. if addr&0x3 != 0 { @@ -341,7 +341,7 @@ func getKey(t Target, addr usermem.Addr, private bool) (Key, error) { } // bucketIndexForAddr returns the index into Manager.buckets for addr. -func bucketIndexForAddr(addr usermem.Addr) uintptr { +func bucketIndexForAddr(addr hostarch.Addr) uintptr { // - The bottom 2 bits of addr must be 0, per getKey. // // - On amd64, the top 16 bits of addr (bits 48-63) must be equal to bit 47 @@ -448,7 +448,7 @@ func (m *Manager) lockBuckets(k1, k2 *Key) (*bucket, *bucket) { // Wake wakes up to n waiters matching the bitmask on the given addr. // The number of waiters woken is returned. -func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32, n int) (int, error) { +func (m *Manager) Wake(t Target, addr hostarch.Addr, private bool, bitmask uint32, n int) (int, error) { // This function is very hot; avoid defer. k, err := getKey(t, addr, private) if err != nil { @@ -463,7 +463,7 @@ func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32 return r, nil } -func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, checkval bool, val uint32, nwake int, nreq int) (int, error) { +func (m *Manager) doRequeue(t Target, addr, naddr hostarch.Addr, private bool, checkval bool, val uint32, nwake int, nreq int) (int, error) { k1, err := getKey(t, addr, private) if err != nil { return 0, err @@ -498,14 +498,14 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch // Requeue wakes up to nwake waiters on the given addr, and unconditionally // requeues up to nreq waiters on naddr. -func (m *Manager) Requeue(t Target, addr, naddr usermem.Addr, private bool, nwake int, nreq int) (int, error) { +func (m *Manager) Requeue(t Target, addr, naddr hostarch.Addr, private bool, nwake int, nreq int) (int, error) { return m.doRequeue(t, addr, naddr, private, false, 0, nwake, nreq) } // RequeueCmp atomically checks that the addr contains val (via the Target), // wakes up to nwake waiters on addr and then unconditionally requeues nreq // waiters on naddr. -func (m *Manager) RequeueCmp(t Target, addr, naddr usermem.Addr, private bool, val uint32, nwake int, nreq int) (int, error) { +func (m *Manager) RequeueCmp(t Target, addr, naddr hostarch.Addr, private bool, val uint32, nwake int, nreq int) (int, error) { return m.doRequeue(t, addr, naddr, private, true, val, nwake, nreq) } @@ -513,7 +513,7 @@ func (m *Manager) RequeueCmp(t Target, addr, naddr usermem.Addr, private bool, v // waiters unconditionally from addr1, and, based on the original value at addr2 // and a comparison encoded in op, wakes up to nwake2 waiters from addr2. // It returns the total number of waiters woken. 
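
atomicOp above unpacks a FUTEX_WAKE_OP operand into an operation, a comparison, and two 12-bit arguments using fixed shifts and masks. A worked decode of one sample encoding (the sample value is arbitrary; the low 12 bits hold the comparison argument in the standard Linux encoding):

    package main

    import "fmt"

    func main() {
        // Encoded as in atomicOp: op in bits 28-31, cmp in bits 24-27,
        // oparg in bits 12-23, cmparg in bits 0-11.
        opIn := uint32(0)
        opIn |= 1 << 28 // opType 1 (ADD in the Linux encoding)
        opIn |= 3 << 24 // cmp 3 (CMP_LE in the Linux encoding)
        opIn |= 5 << 12 // opArg 5
        opIn |= 7       // cmpArg 7

        opType := (opIn >> 28) & 0xf
        cmp := (opIn >> 24) & 0xf
        opArg := (opIn >> 12) & 0xfff
        cmpArg := opIn & 0xfff
        fmt.Println(opType, cmp, opArg, cmpArg) // 1 3 5 7
    }
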
-func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwake1 int, nwake2 int, op uint32) (int, error) { +func (m *Manager) WakeOp(t Target, addr1, addr2 hostarch.Addr, private bool, nwake1 int, nwake2 int, op uint32) (int, error) { k1, err := getKey(t, addr1, private) if err != nil { return 0, err @@ -553,7 +553,7 @@ func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwak // enqueues w to be woken by a send to w.C. If WaitPrepare returns nil, the // Waiter must be subsequently removed by calling WaitComplete, whether or not // a wakeup is received on w.C. -func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) error { +func (m *Manager) WaitPrepare(w *Waiter, t Target, addr hostarch.Addr, private bool, val uint32, bitmask uint32) error { k, err := getKey(t, addr, private) if err != nil { return err @@ -631,7 +631,7 @@ func (m *Manager) WaitComplete(w *Waiter, t Target) { // FUTEX_OWNER_DIED is only set by the Linux when robust lists are in use (see // exit_robust_list()). Given we don't support robust lists, although handled // below, it's never set. -func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, private, try bool) (bool, error) { +func (m *Manager) LockPI(w *Waiter, t Target, addr hostarch.Addr, tid uint32, private, try bool) (bool, error) { k, err := getKey(t, addr, private) if err != nil { return false, err @@ -663,7 +663,7 @@ func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, pri return success, nil } -func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint32, b *bucket, try bool) (bool, error) { +func (m *Manager) lockPILocked(w *Waiter, t Target, addr hostarch.Addr, tid uint32, b *bucket, try bool) (bool, error) { for { cur, err := t.LoadUint32(addr) if err != nil { @@ -724,7 +724,7 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3 // The address provided must contain the caller's TID. If there are waiters, // TID of the next waiter (FIFO) is set to the given address, and the waiter // woken up. If there are no waiters, 0 is set to the address. 
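
LockPI and UnlockPI above treat the futex word as the owner's TID plus flag bits. Assuming the standard Linux layout (FUTEX_WAITERS and FUTEX_OWNER_DIED in the top two bits, the TID in the rest), a small sketch of composing and inspecting the lock word; the constants are restated here for illustration rather than taken from gVisor:

    package main

    import "fmt"

    // Standard Linux PI futex word layout (restated for illustration).
    const (
        futexWaiters   = 0x80000000 // at least one waiter is queued
        futexOwnerDied = 0x40000000 // owner exited without unlocking
        futexTIDMask   = 0x3fffffff // owner thread ID
    )

    func main() {
        const ownerTID = 1234

        // An uncontended lock stores just the owner's TID.
        word := uint32(ownerTID)

        // A second locker that has to sleep sets the waiters bit so the owner
        // knows it must wake someone on unlock.
        word |= futexWaiters

        fmt.Printf("owner=%d waiters=%t ownerDied=%t\n",
            word&futexTIDMask,
            word&futexWaiters != 0,
            word&futexOwnerDied != 0)
    }
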
-func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error { +func (m *Manager) UnlockPI(t Target, addr hostarch.Addr, tid uint32, private bool) error { k, err := getKey(t, addr, private) if err != nil { return err @@ -738,7 +738,7 @@ func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool return err } -func (m *Manager) unlockPILocked(t Target, addr usermem.Addr, tid uint32, b *bucket, key *Key) error { +func (m *Manager) unlockPILocked(t Target, addr hostarch.Addr, tid uint32, b *bucket, key *Key) error { cur, err := t.LoadUint32(addr) if err != nil { return err diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go index ba7f95d8a..deba44e5c 100644 --- a/pkg/sentry/kernel/futex/futex_test.go +++ b/pkg/sentry/kernel/futex/futex_test.go @@ -23,8 +23,8 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) // testData implements the Target interface, and allows us to @@ -43,23 +43,23 @@ func newTestData(size uint) testData { } } -func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) { +func (t testData) SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) { val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), new) return val, nil } -func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) { +func (t testData) CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) { if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), old, new) { return old, nil } return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil } -func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) { +func (t testData) LoadUint32(addr hostarch.Addr) (uint32, error) { return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil } -func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) { +func (t testData) GetSharedKey(addr hostarch.Addr) (Key, error) { return Key{ Kind: KindSharedMappable, Offset: uint64(addr), @@ -73,7 +73,7 @@ func futexKind(private bool) string { return "shared" } -func newPreparedTestWaiter(t *testing.T, m *Manager, ta Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) *Waiter { +func newPreparedTestWaiter(t *testing.T, m *Manager, ta Target, addr hostarch.Addr, private bool, val uint32, bitmask uint32) *Waiter { w := NewWaiter() if err := m.WaitPrepare(w, ta, addr, private, val, bitmask); err != nil { t.Fatalf("WaitPrepare failed: %v", err) @@ -463,12 +463,12 @@ const ( // Beyond being used as a Locker, this is a simple mechanism for // changing the underlying values for simpler tests. 
type testMutex struct { - a usermem.Addr + a hostarch.Addr d testData m *Manager } -func newTestMutex(addr usermem.Addr, d testData, m *Manager) *testMutex { +func newTestMutex(addr hostarch.Addr, d testData, m *Manager) *testMutex { return &testMutex{a: addr, d: d, m: m} } diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go index 4fcdfc541..4b943106b 100644 --- a/pkg/sentry/kernel/kcov.go +++ b/pkg/sentry/kernel/kcov.go @@ -22,13 +22,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/coverage" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // kcovAreaSizeMax is the maximum number of uint64 entries allowed in the kcov @@ -130,7 +130,7 @@ func (kcov *Kcov) InitTrace(size uint64) error { // To simplify all the logic around mapping, we require that the length of the // shared region is a multiple of the system page size. - if (8*size)&(usermem.PageSize-1) != 0 { + if (8*size)&(hostarch.PageSize-1) != 0 { return syserror.EINVAL } @@ -286,7 +286,7 @@ func (rw *kcovReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { } // Get internal mappings. - bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, usermem.Read) + bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, hostarch.Read) if err != nil { return 0, err } @@ -314,7 +314,7 @@ func (rw *kcovReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) } // Get internal mapping. - bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, usermem.Write) + bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, hostarch.Write) if err != nil { return 0, err } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 43065b45a..e6e9da898 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -294,6 +294,11 @@ type Kernel struct { // YAMAPtraceScope is the current level of YAMA ptrace restrictions. YAMAPtraceScope int32 + + // cgroupRegistry contains the set of active cgroup controllers on the + // system. It is controller by cgroupfs. Nil if cgroupfs is unavailable on + // the system. + cgroupRegistry *CgroupRegistry } // InitKernelArgs holds arguments to Init. @@ -438,6 +443,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.socketMount = socketMount k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord) + + k.cgroupRegistry = newCgroupRegistry() } return nil } @@ -1815,6 +1822,11 @@ func (k *Kernel) SocketMount() *vfs.Mount { return k.socketMount } +// CgroupRegistry returns the cgroup registry. +func (k *Kernel) CgroupRegistry() *CgroupRegistry { + return k.cgroupRegistry +} + // Release releases resources owned by k. // // Precondition: This should only be called after the kernel is fully @@ -1831,3 +1843,43 @@ func (k *Kernel) Release() { k.timekeeper.Destroy() k.vdso.Release(ctx) } + +// PopulateNewCgroupHierarchy moves all tasks into a newly created cgroup +// hierarchy. +// +// Precondition: root must be a new cgroup with no tasks. This implies the +// controllers for root are also new and currently manage no task, which in turn +// implies the new cgroup can be populated without migrating tasks between +// cgroups. 
+func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) { + k.tasks.mu.RLock() + k.tasks.forEachTaskLocked(func(t *Task) { + if t.exitState != TaskExitNone { + return + } + t.mu.Lock() + t.enterCgroupLocked(root) + t.mu.Unlock() + }) + k.tasks.mu.RUnlock() +} + +// ReleaseCgroupHierarchy moves all tasks out of all cgroups belonging to the +// hierarchy with the provided id. This is intended for use during hierarchy +// teardown, as otherwise the tasks would be orphaned w.r.t to some controllers. +func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) { + k.tasks.mu.RLock() + k.tasks.forEachTaskLocked(func(t *Task) { + if t.exitState != TaskExitNone { + return + } + t.mu.Lock() + for cg, _ := range t.cgroups { + if cg.HierarchyID() == hid { + t.leaveCgroupLocked(cg) + } + } + t.mu.Unlock() + }) + k.tasks.mu.RUnlock() +} diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index beba6d97d..34c617b08 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/abi/linux", "//pkg/amutex", "//pkg/context", + "//pkg/hostarch", "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index d004f2357..06769931a 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -22,18 +22,18 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) const ( // MinimumPipeSize is a hard limit of the minimum size of a pipe. // It corresponds to fs/pipe.c:pipe_min_size. - MinimumPipeSize = usermem.PageSize + MinimumPipeSize = hostarch.PageSize // MaximumPipeSize is a hard limit on the maximum size of a pipe. // It corresponds to fs/pipe.c:pipe_max_size. @@ -41,7 +41,7 @@ const ( // DefaultPipeSize is the system-wide default size of a pipe in bytes. // It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS. - DefaultPipeSize = 16 * usermem.PageSize + DefaultPipeSize = 16 * hostarch.PageSize // atomicIOBytes is the maximum number of bytes that the pipe will // guarantee atomic reads or writes atomically. 
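
With the move to hostarch, the pipe limits above are still expressed in pages. On a system with 4 KiB pages that makes the minimum pipe one page and the default sixteen pages; a two-line check (the page size is assumed, matching x86-64 and arm64 with 4K pages):

    package main

    import "fmt"

    func main() {
        const pageSize = 4096 // hostarch.PageSize with 4K pages
        fmt.Println("minimum pipe size:", pageSize)    // 4096 bytes
        fmt.Println("default pipe size:", 16*pageSize) // 65536 bytes
    }
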
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index e524afad5..95b948edb 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -17,6 +17,7 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -274,7 +275,7 @@ func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescripti } src := usermem.IOSequence{ IO: fd, - Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), + Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(count)}), } var ( @@ -302,7 +303,7 @@ func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescripti func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescription, off, count int64) (int64, error) { dst := usermem.IOSequence{ IO: fd, - Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), + Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(count)}), } var ( @@ -328,7 +329,7 @@ func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescript // fd.pipe.Notify(waiter.WritableEvents) after the read is completed. // // Preconditions: fd.pipe.mu must be locked. -func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { +func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts usermem.IOOpts) (int, error) { n, err := fd.pipe.peekLocked(int64(len(dst)), func(srcs safemem.BlockSeq) (uint64, error) { return safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), srcs) }) @@ -340,7 +341,7 @@ func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, // is completed. // // Preconditions: fd.pipe.mu must be locked. -func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { +func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts usermem.IOOpts) (int, error) { n, err := fd.pipe.writeLocked(int64(len(src)), func(dsts safemem.BlockSeq) (uint64, error) { return safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src))) }) @@ -350,7 +351,7 @@ func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, // ZeroOut implements usermem.IO.ZeroOut. // // Preconditions: fd.pipe.mu must be locked. -func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { +func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { n, err := fd.pipe.writeLocked(toZero, func(dsts safemem.BlockSeq) (uint64, error) { return safemem.ZeroSeq(dsts) }) @@ -362,7 +363,7 @@ func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int6 // fd.pipe.Notify(waiter.WritableEvents) after the read is completed. // // Preconditions: fd.pipe.mu must be locked. 
-func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { +func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { return fd.pipe.peekLocked(ars.NumBytes(), func(srcs safemem.BlockSeq) (uint64, error) { return dst.WriteFromBlocks(srcs) }) @@ -373,25 +374,25 @@ func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst // is completed. // // Preconditions: fd.pipe.mu must be locked. -func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { +func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { return fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) { return src.ReadToBlocks(dsts) }) } // SwapUint32 implements usermem.IO.SwapUint32. -func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) { +func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr hostarch.Addr, new uint32, opts usermem.IOOpts) (uint32, error) { // How did a pipe get passed as the virtual address space to futex(2)? panic("VFSPipeFD.SwapUint32 called unexpectedly") } // CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32. -func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) { +func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr hostarch.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) { panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly") } // LoadUint32 implements usermem.IO.LoadUint32. -func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) { +func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr hostarch.Addr, opts usermem.IOOpts) (uint32, error) { panic("VFSPipeFD.LoadUint32 called unexpectedly") } diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index f5a60e749..57c7659e7 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -1011,7 +1012,7 @@ func (t *Task) ptraceSetOptionsLocked(opts uintptr) error { } // Ptrace implements the ptrace system call. -func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { +func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error { // PTRACE_TRACEME ignores all other arguments. if req == linux.PTRACE_TRACEME { return t.ptraceTraceme() @@ -1190,7 +1191,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { panic(fmt.Sprintf("%#x + %#x overflows. 
Invalid reg size > %#x", ar.Start, n, ar.Length())) } ar.End = end - return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar)) + return t.CopyOutIovecs(data, hostarch.AddrRangeSeqOf(ar)) case linux.PTRACE_SETREGSET: ars, err := t.CopyInIovecs(data, 1) @@ -1214,8 +1215,8 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { return err } t.p.FullStateChanged() - ar.End -= usermem.Addr(n) - return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar)) + ar.End -= hostarch.Addr(n) + return t.CopyOutIovecs(data, hostarch.AddrRangeSeqOf(ar)) case linux.PTRACE_GETSIGINFO: t.tg.pidns.owner.mu.RLock() @@ -1267,7 +1268,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { case linux.PTRACE_GETEVENTMSG: t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() - _, err := primitive.CopyUint64Out(t, usermem.Addr(data), target.ptraceEventMsg) + _, err := primitive.CopyUint64Out(t, hostarch.Addr(data), target.ptraceEventMsg) return err // PEEKSIGINFO is unimplemented but seems to have no users anywhere. diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go index 7aea3dcd8..5ae05b5c3 100644 --- a/pkg/sentry/kernel/ptrace_amd64.go +++ b/pkg/sentry/kernel/ptrace_amd64.go @@ -18,12 +18,13 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) // ptraceArch implements arch-specific ptrace commands. -func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) error { +func (t *Task) ptraceArch(target *Task, req int64, addr, data hostarch.Addr) error { switch req { case linux.PTRACE_PEEKUSR: // aka PTRACE_PEEKUSER n, err := target.Arch().PtracePeekUser(uintptr(addr)) diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go index d971b96b3..46dd84cbc 100644 --- a/pkg/sentry/kernel/ptrace_arm64.go +++ b/pkg/sentry/kernel/ptrace_arm64.go @@ -17,11 +17,11 @@ package kernel import ( + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // ptraceArch implements arch-specific ptrace commands. -func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) error { +func (t *Task) ptraceArch(target *Task, req int64, addr, data hostarch.Addr) error { return syserror.EIO } diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index 2a9023fdf..4bc5bca44 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/hostcpu" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -43,8 +44,8 @@ type OldRSeqCriticalRegion struct { // application handler while its instruction pointer is in CriticalSection, // set the instruction pointer to Restart and application register r10 (on // amd64) to the former instruction pointer. - CriticalSection usermem.AddrRange - Restart usermem.Addr + CriticalSection hostarch.AddrRange + Restart hostarch.Addr } // RSeqAvailable returns true if t supports (old and new) restartable sequences. @@ -55,7 +56,7 @@ func (t *Task) RSeqAvailable() bool { // SetRSeq registers addr as this thread's rseq structure. // // Preconditions: The caller must be running on the task goroutine. 
-func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error { +func (t *Task) SetRSeq(addr hostarch.Addr, length, signature uint32) error { if t.rseqAddr != 0 { if t.rseqAddr != addr { return syserror.EINVAL @@ -100,7 +101,7 @@ func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error { // ClearRSeq unregisters addr as this thread's rseq structure. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) ClearRSeq(addr usermem.Addr, length, signature uint32) error { +func (t *Task) ClearRSeq(addr hostarch.Addr, length, signature uint32) error { if t.rseqAddr == 0 { return syserror.EINVAL } @@ -166,7 +167,7 @@ func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error { // CPU number. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) OldRSeqCPUAddr() usermem.Addr { +func (t *Task) OldRSeqCPUAddr() hostarch.Addr { return t.oldRSeqCPUAddr } @@ -177,7 +178,7 @@ func (t *Task) OldRSeqCPUAddr() usermem.Addr { // * t.RSeqAvailable() == true. // * The caller must be running on the task goroutine. // * t's AddressSpace must be active. -func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error { +func (t *Task) SetOldRSeqCPUAddr(addr hostarch.Addr) error { t.oldRSeqCPUAddr = addr // Check that addr is writable. @@ -221,7 +222,7 @@ func (t *Task) oldRSeqCopyOutCPU() error { } buf := t.CopyScratchBuffer(4) - usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) + hostarch.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) _, err := t.CopyOutBytes(t.oldRSeqCPUAddr, buf) return err } @@ -236,8 +237,8 @@ func (t *Task) rseqCopyOutCPU() error { buf := t.CopyScratchBuffer(8) // CPUIDStart and CPUID are the first two fields in linux.RSeq. - usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart - usermem.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID + hostarch.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart + hostarch.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID // N.B. This write is not atomic, but since this occurs on the task // goroutine then as long as userspace uses a single-instruction read // it can't see an invalid value. @@ -251,8 +252,8 @@ func (t *Task) rseqCopyOutCPU() error { func (t *Task) rseqClearCPU() error { buf := t.CopyScratchBuffer(8) // CPUIDStart and CPUID are the first two fields in linux.RSeq. - usermem.ByteOrder.PutUint32(buf, 0) // CPUIDStart - usermem.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID + hostarch.ByteOrder.PutUint32(buf, 0) // CPUIDStart + hostarch.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID // N.B. This write is not atomic, but since this occurs on the task // goroutine then as long as userspace uses a single-instruction read // it can't see an invalid value. 
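Aside: the rseqCopyOutCPU/rseqClearCPU hunks above write the first two fields of the userspace linux.RSeq structure (cpu_id_start, cpu_id) into an 8-byte scratch buffer using hostarch.ByteOrder before copying it out. A minimal illustrative sketch (not gVisor code) of that packing, assuming a little-endian host, which is what hostarch.ByteOrder resolves to on the architectures gVisor supports (amd64, arm64):

package main

import (
	"encoding/binary"
	"fmt"
)

// packRSeqCPU packs the same CPU number into both cpu_id_start and cpu_id,
// mirroring the buffer layout written by rseqCopyOutCPU above.
func packRSeqCPU(cpu int32) []byte {
	buf := make([]byte, 8)
	binary.LittleEndian.PutUint32(buf[0:4], uint32(cpu)) // cpu_id_start
	binary.LittleEndian.PutUint32(buf[4:8], uint32(cpu)) // cpu_id
	return buf
}

func main() {
	// Userspace reads either field with a single 4-byte load, so the
	// non-atomic 8-byte store performed by the task goroutine is safe.
	fmt.Printf("% x\n", packRSeqCPU(3))
}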
@@ -305,7 +306,7 @@ func (t *Task) rseqAddrInterrupt() { return } - critAddr := usermem.Addr(usermem.ByteOrder.Uint64(buf)) + critAddr := hostarch.Addr(hostarch.ByteOrder.Uint64(buf)) if critAddr == 0 { return } @@ -325,7 +326,7 @@ func (t *Task) rseqAddrInterrupt() { return } - start := usermem.Addr(cs.Start) + start := hostarch.Addr(cs.Start) critRange, ok := start.ToRange(cs.PostCommitOffset) if !ok { t.Debugf("Invalid start and offset in %+v", cs) @@ -334,7 +335,7 @@ func (t *Task) rseqAddrInterrupt() { return } - abort := usermem.Addr(cs.Abort) + abort := hostarch.Addr(cs.Abort) if critRange.Contains(abort) { t.Debugf("Abort in critical section in %+v", cs) t.forceSignal(linux.SIGSEGV, false /* unconditional */) @@ -353,7 +354,7 @@ func (t *Task) rseqAddrInterrupt() { return } - sig := usermem.ByteOrder.Uint32(buf) + sig := hostarch.ByteOrder.Uint32(buf) if sig != t.rseqSignature { t.Debugf("Mismatched rseq signature %d != %d", sig, t.rseqSignature) t.forceSignal(linux.SIGSEGV, false /* unconditional */) @@ -376,7 +377,7 @@ func (t *Task) rseqAddrInterrupt() { } // Finally we can actually decide whether or not to restart. - if !critRange.Contains(usermem.Addr(t.Arch().IP())) { + if !critRange.Contains(hostarch.Addr(t.Arch().IP())) { return } @@ -386,7 +387,7 @@ func (t *Task) rseqAddrInterrupt() { // Preconditions: The caller must be running on the task goroutine. func (t *Task) oldRSeqInterrupt() { r := t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion) - if ip := t.Arch().IP(); r.CriticalSection.Contains(usermem.Addr(ip)) { + if ip := t.Arch().IP(); r.CriticalSection.Contains(hostarch.Addr(ip)) { t.Debugf("Interrupted rseq critical section at %#x; restarting at %#x", ip, r.Restart) t.Arch().SetIP(uintptr(r.Restart)) t.Arch().SetOldRSeqInterruptedIP(ip) diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index 8163a6132..a95e174a2 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -18,9 +18,9 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) const maxSyscallFilterInstructions = 1 << 15 @@ -35,11 +35,11 @@ func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input { return bpf.InputBytes{ Data: buf, // Go-marshal always uses the native byte order. - Order: usermem.ByteOrder, + Order: hostarch.ByteOrder, } } -func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo { +func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *arch.SignalInfo { si := &arch.SignalInfo{ Signo: int32(linux.SIGSYS), Errno: errno, @@ -56,7 +56,7 @@ func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalIn // in because vsyscalls do not use the values in t.Arch().) // // Preconditions: The caller must be running on the task goroutine. 
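Aside: checkSeccompSyscall below masks the filter result with linux.SECCOMP_RET_ACTION to recover the action, while the low bits carry per-action data (for example the errno for SECCOMP_RET_ERRNO). An illustrative sketch (not gVisor code) of that decomposition; the constant values are taken from uapi/linux/seccomp.h:

package main

import "fmt"

const (
	seccompRetAction = 0x7fff0000 // action mask (SECCOMP_RET_ACTION)
	seccompRetData   = 0x0000ffff // data mask (SECCOMP_RET_DATA)
	seccompRetErrno  = 0x00050000 // SECCOMP_RET_ERRNO
)

// splitSeccompResult separates a BPF filter return value into its action and
// data portions.
func splitSeccompResult(result uint32) (action, data uint32) {
	return result & seccompRetAction, result & seccompRetData
}

func main() {
	// A filter returning SECCOMP_RET_ERRNO | 1 asks for errno 1 (EPERM).
	action, data := splitSeccompResult(seccompRetErrno | 1)
	fmt.Printf("action=%#x errno=%d\n", action, data)
}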
-func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip usermem.Addr) linux.BPFAction { +func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip hostarch.Addr) linux.BPFAction { result := linux.BPFAction(t.evaluateSyscallFilters(sysno, args, ip)) action := result & linux.SECCOMP_RET_ACTION switch action { @@ -102,7 +102,7 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u return action } -func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 { +func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip hostarch.Addr) uint32 { data := linux.SeccompData{ Nr: sysno, Arch: t.image.st.AuditNumber, diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 073e14507..1c3c0794f 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -28,6 +28,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 92d60ba78..a73f1bdca 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -38,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -47,7 +48,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Key represents a shm segment key. Analogous to a file name. @@ -197,13 +197,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui } var sizeAligned uint64 - if val, ok := usermem.Addr(size).RoundUp(); ok { + if val, ok := hostarch.Addr(size).RoundUp(); ok { sizeAligned = uint64(val) } else { return nil, syserror.EINVAL } - if numPages := sizeAligned / usermem.PageSize; r.totalPages+numPages > linux.SHMALL { + if numPages := sizeAligned / hostarch.PageSize; r.totalPages+numPages > linux.SHMALL { // "... allocating a segment of the requested size would cause the // system to exceed the system-wide limit on shared memory (SHMALL)." // - man shmget(2) @@ -232,7 +232,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider)) } - effectiveSize := uint64(usermem.Addr(size).MustRoundUp()) + effectiveSize := uint64(hostarch.Addr(size).MustRoundUp()) fr, err := mfp.MemoryFile().Allocate(effectiveSize, usage.Anonymous) if err != nil { return nil, err @@ -267,7 +267,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi r.shms[id] = shm r.keysToShms[key] = shm - r.totalPages += effectiveSize / usermem.PageSize + r.totalPages += effectiveSize / hostarch.PageSize return shm, nil } @@ -318,7 +318,7 @@ func (r *Registry) remove(s *Shm) { } delete(r.shms, s.ID) - r.totalPages -= s.effectiveSize / usermem.PageSize + r.totalPages -= s.effectiveSize / hostarch.PageSize } // Release drops the self-reference of each active shm segment in the registry. @@ -386,7 +386,7 @@ type Shm struct { // effectiveSize of the segment, rounding up to the next page // boundary. Immutable. // - // Invariant: effectiveSize must be a multiple of usermem.PageSize. 
+ // Invariant: effectiveSize must be a multiple of hostarch.PageSize. effectiveSize uint64 // fr is the offset into mfp.MemoryFile() that backs this contents of this @@ -467,7 +467,7 @@ func (s *Shm) Msync(context.Context, memmap.MappableRange) error { } // AddMapping implements memmap.Mappable.AddMapping. -func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) error { +func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ hostarch.AddrRange, _ uint64, _ bool) error { s.mu.Lock() defer s.mu.Unlock() s.attachTime = ktime.NowFromContext(ctx) @@ -482,7 +482,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.A } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) { +func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ hostarch.AddrRange, _ uint64, _ bool) { s.mu.Lock() defer s.mu.Unlock() // RemoveMapping may be called during task exit, when ctx @@ -503,12 +503,12 @@ func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ userme } // CopyMapping implements memmap.Mappable.CopyMapping. -func (*Shm) CopyMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, usermem.AddrRange, uint64, bool) error { +func (*Shm) CopyMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, hostarch.AddrRange, uint64, bool) error { return nil } // Translate implements memmap.Mappable.Translate. -func (s *Shm) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (s *Shm) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { var err error if required.End > s.fr.Length() { err = &memmap.BusError{syserror.EFAULT} @@ -519,7 +519,7 @@ func (s *Shm) Translate(ctx context.Context, required, optional memmap.MappableR Source: source, File: s.mfp.MemoryFile(), Offset: s.fr.Start + source.Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }, }, err } @@ -543,7 +543,7 @@ type AttachOpts struct { // // Postconditions: The returned MMapOpts are valid only as long as a reference // continues to be held on s. -func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) { +func (s *Shm) ConfigureAttach(ctx context.Context, addr hostarch.Addr, opts AttachOpts) (memmap.MMapOpts, error) { s.mu.Lock() defer s.mu.Unlock() if s.pendingDestruction && s.ReadRefs() == 0 { @@ -565,12 +565,12 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts Attac Offset: 0, Addr: addr, Fixed: opts.Remap, - Perms: usermem.AccessType{ + Perms: hostarch.AccessType{ Read: true, Write: !opts.Readonly, Execute: opts.Execute, }, - MaxPerms: usermem.AnyAccess, + MaxPerms: hostarch.AnyAccess, Mappable: s, MappingIdentity: s, }, nil diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 332bdb8e8..953d4310e 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -20,9 +20,9 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) // maxSyscallNum is the highest supported syscall number. 
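Aside: the shm Registry hunks above round the requested segment size up to a page boundary (failing with EINVAL on overflow) and charge the resulting page count against the system-wide SHMALL limit. A minimal illustrative sketch (not gVisor code) of that arithmetic; pageSize and shmAll are stand-ins for hostarch.PageSize on 4K-page hosts and linux.SHMALL:

package main

import (
	"errors"
	"fmt"
)

const (
	pageSize = 4096               // stand-in for hostarch.PageSize
	shmAll   = 1024 * 1024 * 1024 // stand-in for linux.SHMALL (system-wide page limit)
)

// roundUpToPage mirrors hostarch.Addr.RoundUp: round size up to the next page
// boundary, reporting failure if the addition overflows.
func roundUpToPage(size uint64) (uint64, bool) {
	aligned := size + pageSize - 1
	if aligned < size {
		return 0, false // overflow
	}
	return aligned &^ (pageSize - 1), true
}

// checkSHMALL reproduces the accounting decision: reject a segment whose
// page count would push the registry past the SHMALL limit.
func checkSHMALL(totalPages, size uint64) error {
	aligned, ok := roundUpToPage(size)
	if !ok {
		return errors.New("EINVAL: size overflows when rounded up")
	}
	if totalPages+aligned/pageSize > shmAll {
		return errors.New("ENOSPC: would exceed SHMALL")
	}
	return nil
}

func main() {
	fmt.Println(checkSHMALL(0, 5000)) // 5000 bytes round up to 2 pages, fits
}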
@@ -243,7 +243,7 @@ type SyscallTable struct { // Emulate is a collection of instruction addresses to emulate. The // keys are addresses, and the values are system call numbers. - Emulate map[usermem.Addr]uintptr + Emulate map[hostarch.Addr]uintptr // The function to call in case of a missing system call. Missing MissingFn @@ -316,7 +316,7 @@ func (s *SyscallTable) Init() { } if s.Emulate == nil { // Ensure non-nil emulate table. - s.Emulate = make(map[usermem.Addr]uintptr) + s.Emulate = make(map[hostarch.Addr]uintptr) } max := s.MaxSysno() // Checked during RegisterSyscallTable. @@ -359,7 +359,7 @@ func (s *SyscallTable) LookupNo(name string) (uintptr, error) { } // LookupEmulate looks up an emulation syscall number. -func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) { +func (s *SyscallTable) LookupEmulate(addr hostarch.Addr) (uintptr, bool) { sysno, ok := s.Emulate[addr] return sysno, ok } diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 36141dd09..be1371855 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -33,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -470,7 +470,7 @@ type Task struct { // ThreadID to 0, and wake any futex waiters. // // cleartid is exclusive to the task goroutine. - cleartid usermem.Addr + cleartid hostarch.Addr // This is mostly a fake cpumask just for sched_set/getaffinity as we // don't really control the affinity. @@ -540,12 +540,12 @@ type Task struct { // oldRSeqCPUAddr is a pointer to the userspace old rseq CPU variable. // // oldRSeqCPUAddr is exclusive to the task goroutine. - oldRSeqCPUAddr usermem.Addr + oldRSeqCPUAddr hostarch.Addr // rseqAddr is a pointer to the userspace linux.RSeq structure. // // rseqAddr is exclusive to the task goroutine. - rseqAddr usermem.Addr + rseqAddr hostarch.Addr // rseqSignature is the signature that the rseq abort IP must be signed // with. @@ -575,7 +575,7 @@ type Task struct { // robustList is a pointer to the head of the tasks's robust futex // list. - robustList usermem.Addr + robustList hostarch.Addr // startTime is the real time at which the task started. It is set when // a Task is created or invokes execve(2). @@ -587,6 +587,12 @@ type Task struct { // // kcov is exclusive to the task goroutine. kcov *Kcov + + // cgroups is the set of cgroups this task belongs to. This may be empty if + // no cgroup controllers are enabled. Protected by mu. + // + // +checklocks:mu + cgroups map[Cgroup]struct{} } func (t *Task) savePtraceTracer() *Task { @@ -652,7 +658,7 @@ func (t *Task) Kernel() *Kernel { // SetClearTID sets t's cleartid. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) SetClearTID(addr usermem.Addr) { +func (t *Task) SetClearTID(addr hostarch.Addr) { t.cleartid = addr } diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go new file mode 100644 index 000000000..25d2504fa --- /dev/null +++ b/pkg/sentry/kernel/task_cgroup.go @@ -0,0 +1,138 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "bytes" + "fmt" + "sort" + "strings" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/syserror" +) + +// EnterInitialCgroups moves t into an initial set of cgroups. +// +// Precondition: t isn't in any cgroups yet, t.cgs is empty. +// +// +checklocksignore parent.mu is conditionally acquired. +func (t *Task) EnterInitialCgroups(parent *Task) { + var inherit map[Cgroup]struct{} + if parent != nil { + parent.mu.Lock() + defer parent.mu.Unlock() + inherit = parent.cgroups + } + joinSet := t.k.cgroupRegistry.computeInitialGroups(inherit) + + t.mu.Lock() + defer t.mu.Unlock() + // Transfer ownership of joinSet refs to the task's cgset. + t.cgroups = joinSet + for c, _ := range t.cgroups { + // Since t isn't in any cgroup yet, we can skip the check against + // existing cgroups. + c.Enter(t) + } +} + +// EnterCgroup moves t into c. +func (t *Task) EnterCgroup(c Cgroup) error { + newControllers := make(map[CgroupControllerType]struct{}) + for _, ctl := range c.Controllers() { + newControllers[ctl.Type()] = struct{}{} + } + + t.mu.Lock() + defer t.mu.Unlock() + + for oldCG, _ := range t.cgroups { + for _, oldCtl := range oldCG.Controllers() { + if _, ok := newControllers[oldCtl.Type()]; ok { + // Already in a cgroup with the same controller as one of the + // new ones. Requires migration between cgroups. + // + // TODO(b/183137098): Implement cgroup migration. + log.Warningf("Cgroup migration is not implemented") + return syserror.EBUSY + } + } + } + + // No migration required. + t.enterCgroupLocked(c) + + return nil +} + +// +checklocks:t.mu +func (t *Task) enterCgroupLocked(c Cgroup) { + c.IncRef() + t.cgroups[c] = struct{}{} + c.Enter(t) +} + +// LeaveCgroups removes t out from all its cgroups. +func (t *Task) LeaveCgroups() { + t.mu.Lock() + defer t.mu.Unlock() + for c, _ := range t.cgroups { + t.leaveCgroupLocked(c) + } +} + +// +checklocks:t.mu +func (t *Task) leaveCgroupLocked(c Cgroup) { + c.Leave(t) + delete(t.cgroups, c) + c.decRef() +} + +// taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to +// format a cgroup for display. +type taskCgroupEntry struct { + hierarchyID uint32 + controllers string + path string +} + +// GenerateProcTaskCgroup writes the contents of /proc/<pid>/cgroup for t to buf. +func (t *Task) GenerateProcTaskCgroup(buf *bytes.Buffer) { + t.mu.Lock() + defer t.mu.Unlock() + + cgEntries := make([]taskCgroupEntry, 0, len(t.cgroups)) + for c, _ := range t.cgroups { + ctls := c.Controllers() + ctlNames := make([]string, 0, len(ctls)) + for _, ctl := range ctls { + ctlNames = append(ctlNames, string(ctl.Type())) + } + + cgEntries = append(cgEntries, taskCgroupEntry{ + // Note: We're guaranteed to have at least one controller, and all + // controllers are guaranteed to be on the same hierarchy. 
+ hierarchyID: ctls[0].HierarchyID(), + controllers: strings.Join(ctlNames, ","), + path: c.Path(), + }) + } + + sort.Slice(cgEntries, func(i, j int) bool { return cgEntries[i].hierarchyID > cgEntries[j].hierarchyID }) + for _, cgE := range cgEntries { + fmt.Fprintf(buf, "%d:%s:%s\n", cgE.hierarchyID, cgE.controllers, cgE.path) + } +} diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index f305e69c0..405771f3f 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -85,12 +86,12 @@ type CloneOptions struct { // Stack is the initial stack pointer of the new task. If Stack is 0, the // new task will start with the same stack pointer as its parent. - Stack usermem.Addr + Stack hostarch.Addr // If SetTLS is true, set the new task's TLS (thread-local storage) // descriptor to TLS. If SetTLS is false, TLS is ignored. SetTLS bool - TLS usermem.Addr + TLS hostarch.Addr // If ChildClearTID is true, when the child exits, 0 is written to the // address ChildTID in the child's memory, and if the write is successful a @@ -101,7 +102,7 @@ type CloneOptions struct { // Linux, failed writes are silently ignored.) ChildClearTID bool ChildSetTID bool - ChildTID usermem.Addr + ChildTID hostarch.Addr // If ParentSetTID is true, the child's thread ID (in the parent's PID // namespace) is written to address ParentTID in the parent's memory. (As @@ -112,7 +113,7 @@ type CloneOptions struct { // and child's memory, but this is a documentation error fixed by // 87ab04792ced ("clone.2: Fix description of CLONE_PARENT_SETTID"). ParentSetTID bool - ParentTID usermem.Addr + ParentTID hostarch.Addr // If Vfork is true, place the parent in vforkStop until the cloned task // releases its TaskImage. @@ -268,7 +269,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } tg := t.tg - rseqAddr := usermem.Addr(0) + rseqAddr := hostarch.Addr(0) rseqSignature := uint32(0) if opts.NewThreadGroup { if tg.mounts != nil { diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index ad59e4f60..b1af1a7ef 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -275,6 +275,10 @@ func (*runExitMain) execute(t *Task) taskRunState { t.fsContext.DecRef(t) t.fdTable.DecRef(t) + // Detach task from all cgroups. This must happen before potentially the + // last ref to the cgroupfs mount is dropped below. + t.LeaveCgroups() + t.mu.Lock() if t.mountNamespaceVFS2 != nil { t.mountNamespaceVFS2.DecRef(t) diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index 195c7da9b..4dc41b82b 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -16,6 +16,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/usermem" @@ -30,33 +31,33 @@ func (t *Task) Futex() *futex.Manager { } // SwapUint32 implements futex.Target.SwapUint32. 
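Aside: GenerateProcTaskCgroup in the task_cgroup.go hunk above emits one "hierarchyID:controllers:path" line per cgroup, sorted by descending hierarchy ID. An illustrative sketch (not gVisor code) of that formatting with plain local types standing in for the task's cgroup set:

package main

import (
	"bytes"
	"fmt"
	"sort"
	"strings"
)

type entry struct {
	hierarchyID uint32
	controllers []string
	path        string
}

// formatProcCgroup renders entries in the /proc/<pid>/cgroup format used above.
func formatProcCgroup(entries []entry) string {
	sort.Slice(entries, func(i, j int) bool { return entries[i].hierarchyID > entries[j].hierarchyID })
	var buf bytes.Buffer
	for _, e := range entries {
		fmt.Fprintf(&buf, "%d:%s:%s\n", e.hierarchyID, strings.Join(e.controllers, ","), e.path)
	}
	return buf.String()
}

func main() {
	fmt.Print(formatProcCgroup([]entry{
		{1, []string{"cpu", "cpuacct"}, "/"},
		{2, []string{"memory"}, "/docker/abc"},
	}))
}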
-func (t *Task) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) { +func (t *Task) SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) { return t.MemoryManager().SwapUint32(t, addr, new, usermem.IOOpts{ AddressSpaceActive: true, }) } // CompareAndSwapUint32 implements futex.Target.CompareAndSwapUint32. -func (t *Task) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) { +func (t *Task) CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) { return t.MemoryManager().CompareAndSwapUint32(t, addr, old, new, usermem.IOOpts{ AddressSpaceActive: true, }) } // LoadUint32 implements futex.Target.LoadUint32. -func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) { +func (t *Task) LoadUint32(addr hostarch.Addr) (uint32, error) { return t.MemoryManager().LoadUint32(t, addr, usermem.IOOpts{ AddressSpaceActive: true, }) } // GetSharedKey implements futex.Target.GetSharedKey. -func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) { +func (t *Task) GetSharedKey(addr hostarch.Addr) (futex.Key, error) { return t.MemoryManager().GetSharedFutexKey(t, addr) } // GetRobustList sets the robust futex list for the task. -func (t *Task) GetRobustList() usermem.Addr { +func (t *Task) GetRobustList() hostarch.Addr { t.mu.Lock() addr := t.robustList t.mu.Unlock() @@ -64,7 +65,7 @@ func (t *Task) GetRobustList() usermem.Addr { } // SetRobustList sets the robust futex list for the task. -func (t *Task) SetRobustList(addr usermem.Addr) { +func (t *Task) SetRobustList(addr hostarch.Addr) { t.mu.Lock() t.robustList = addr t.mu.Unlock() @@ -84,28 +85,28 @@ func (t *Task) exitRobustList() { } var rl linux.RobustListHead - if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil { + if _, err := rl.CopyIn(t, hostarch.Addr(addr)); err != nil { return } next := primitive.Uint64(rl.List) done := 0 - var pendingLockAddr usermem.Addr + var pendingLockAddr hostarch.Addr if rl.ListOpPending != 0 { - pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset) + pendingLockAddr = hostarch.Addr(rl.ListOpPending + rl.FutexOffset) } // Wake up normal elements. - for usermem.Addr(next) != addr { + for hostarch.Addr(next) != addr { // We traverse to the next element of the list before we // actually wake anything. This prevents the race where waking // this futex causes a modification of the list. - thisLockAddr := usermem.Addr(uint64(next) + rl.FutexOffset) + thisLockAddr := hostarch.Addr(uint64(next) + rl.FutexOffset) // Try to decode the next element in the list before waking the // current futex. But don't check the error until after we've // woken the current futex. Linux does it in this order too - _, nextErr := next.CopyIn(t, usermem.Addr(next)) + _, nextErr := next.CopyIn(t, hostarch.Addr(next)) // Wakeup the current futex if it's not pending. if thisLockAddr != pendingLockAddr { @@ -133,7 +134,7 @@ func (t *Task) exitRobustList() { } // wakeRobustListOne wakes a single futex from the robust list. -func (t *Task) wakeRobustListOne(addr usermem.Addr) { +func (t *Task) wakeRobustListOne(addr hostarch.Addr) { // Bit 0 in address signals PI futex. 
pi := addr&1 == 1 addr = addr &^ 1 diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go index ce5fbd299..bd5543d4e 100644 --- a/pkg/sentry/kernel/task_image.go +++ b/pkg/sentry/kernel/task_image.go @@ -19,12 +19,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/usermem" ) var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC) @@ -129,7 +129,7 @@ func (t *Task) Stack() *arch.Stack { return &arch.Stack{ Arch: t.Arch(), IO: t.MemoryManager(), - Bottom: usermem.Addr(t.Arch().Stack()), + Bottom: hostarch.Addr(t.Arch().Stack()), } } diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index c70e5e6ce..72b9a0384 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -20,6 +20,7 @@ import ( "sort" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/usermem" ) @@ -108,9 +109,9 @@ func (t *Task) debugDumpStack() { return } t.Debugf("Stack:") - start := usermem.Addr(t.Arch().Stack()) + start := hostarch.Addr(t.Arch().Stack()) // Round addr down to a 16-byte boundary. - start &= ^usermem.Addr(15) + start &= ^hostarch.Addr(15) // Print 16 bytes per line, one byte at a time. for offset := uint64(0); offset < maxStackDebugBytes; offset += 16 { addr, ok := start.AddLength(offset) @@ -127,7 +128,7 @@ func (t *Task) debugDumpStack() { t.Debugf("%x: % x", addr, data[:n]) } if err != nil { - t.Debugf("Error reading stack at address %x: %v", addr+usermem.Addr(n), err) + t.Debugf("Error reading stack at address %x: %v", addr+hostarch.Addr(n), err) break } } @@ -147,9 +148,9 @@ func (t *Task) debugDumpCode() { } t.Debugf("Code:") // Print code on both sides of the instruction register. - start := usermem.Addr(t.Arch().IP()) - maxCodeDebugBytes/2 + start := hostarch.Addr(t.Arch().IP()) - maxCodeDebugBytes/2 // Round addr down to a 16-byte boundary. - start &= ^usermem.Addr(15) + start &= ^hostarch.Addr(15) // Print 16 bytes per line, one byte at a time. for offset := uint64(0); offset < maxCodeDebugBytes; offset += 16 { addr, ok := start.AddLength(offset) @@ -166,7 +167,7 @@ func (t *Task) debugDumpCode() { t.Debugf("%x: % x", addr, data[:n]) } if err != nil { - t.Debugf("Error reading stack at address %x: %v", addr+usermem.Addr(n), err) + t.Debugf("Error reading stack at address %x: %v", addr+hostarch.Addr(n), err) break } } diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 3ccecf4b6..068f25af1 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -23,13 +23,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/goid" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/hostcpu" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // A taskRunState is a reified state in the task state machine. 
See README.md @@ -148,7 +148,7 @@ func (*runApp) handleCPUIDInstruction(t *Task) error { region := trace.StartRegion(t.traceContext, cpuidRegion) expected := arch.CPUIDInstruction[:] found := make([]byte, len(expected)) - _, err := t.CopyInBytes(usermem.Addr(t.Arch().IP()), found) + _, err := t.CopyInBytes(hostarch.Addr(t.Arch().IP()), found) if err == nil && bytes.Equal(expected, found) { // Skip the cpuid instruction. t.Arch().CPUIDEmulate(t) @@ -307,8 +307,8 @@ func (app *runApp) execute(t *Task) taskRunState { // normally. if at.Any() { region := trace.StartRegion(t.traceContext, faultRegion) - addr := usermem.Addr(info.Addr()) - err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack())) + addr := hostarch.Addr(info.Addr()) + err := t.MemoryManager().HandleUserFault(t, addr, at, hostarch.Addr(t.Arch().Stack())) region.End() if err == nil { // The fault was handled appropriately. diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 75af3af79..c2b9fc08f 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -23,11 +23,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/eventchannel" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -243,7 +243,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // Are executing on the main stack, // or the provided alternate stack? - sp := usermem.Addr(t.Arch().Stack()) + sp := hostarch.Addr(t.Arch().Stack()) // N.B. This is a *copy* of the alternate stack that the user's signal // handler expects to see in its ucontext (even if it's not in use). @@ -251,7 +251,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) if act.IsOnStack() && alt.IsEnabled() { alt.SetOnStack() if !alt.Contains(sp) { - sp = usermem.Addr(alt.Top()) + sp = hostarch.Addr(alt.Top()) } } @@ -652,7 +652,7 @@ func (t *Task) SignalStack() arch.SignalStack { // onSignalStack returns true if the task is executing on the given signal stack. func (t *Task) onSignalStack(alt arch.SignalStack) bool { - sp := usermem.Addr(t.Arch().Stack()) + sp := hostarch.Addr(t.Arch().Stack()) return alt.Contains(sp) } @@ -720,7 +720,7 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a // CopyOutSignalAct converts the given SignalAct into an architecture-specific // type and then copies it out to task memory. -func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error { +func (t *Task) CopyOutSignalAct(addr hostarch.Addr, s *arch.SignalAct) error { n := t.Arch().NewSignalAct() n.SerializeFrom(s) _, err := n.CopyOut(t, addr) @@ -729,7 +729,7 @@ func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error { // CopyInSignalAct copies an architecture-specific sigaction type from task // memory and then converts it into a SignalAct. 
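Aside: deliverSignalToHandler above picks the stack pointer for the handler frame: it stays on the current stack unless the sigaction requested the alternate stack, the alternate stack is enabled, and the task is not already on it. A hand-rolled illustrative sketch (not gVisor code), with local types standing in for hostarch.Addr and arch.SignalStack:

package main

import "fmt"

type addr uint64

type signalStack struct {
	base addr
	size uint64
}

func (s signalStack) enabled() bool         { return s.size != 0 }
func (s signalStack) top() addr             { return s.base + addr(s.size) }
func (s signalStack) contains(sp addr) bool { return s.base < sp && sp <= s.top() }

// signalSP returns the stack pointer a handler should start on: the alternate
// stack top if SA_ONSTACK was requested, the alt stack is enabled, and we are
// not already running on it; otherwise the current stack pointer.
func signalSP(sp addr, onStack bool, alt signalStack) addr {
	if onStack && alt.enabled() && !alt.contains(sp) {
		return alt.top()
	}
	return sp
}

func main() {
	alt := signalStack{base: 0x70000000, size: 0x2000}
	fmt.Printf("%#x\n", signalSP(0x7ffff000, true, alt)) // switches to alt.top()
}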
-func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) { +func (t *Task) CopyInSignalAct(addr hostarch.Addr) (arch.SignalAct, error) { n := t.Arch().NewSignalAct() var s arch.SignalAct if _, err := n.CopyIn(t, addr); err != nil { @@ -741,7 +741,7 @@ func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) { // CopyOutSignalStack converts the given SignalStack into an // architecture-specific type and then copies it out to task memory. -func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error { +func (t *Task) CopyOutSignalStack(addr hostarch.Addr, s *arch.SignalStack) error { n := t.Arch().NewSignalStack() n.SerializeFrom(s) _, err := n.CopyOut(t, addr) @@ -750,7 +750,7 @@ func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error // CopyInSignalStack copies an architecture-specific stack_t from task memory // and then converts it into a SignalStack. -func (t *Task) CopyInSignalStack(addr usermem.Addr) (arch.SignalStack, error) { +func (t *Task) CopyInSignalStack(addr hostarch.Addr) (arch.SignalStack, error) { n := t.Arch().NewSignalStack() var s arch.SignalStack if _, err := n.CopyIn(t, addr); err != nil { diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 36e1384f1..32031cd70 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -25,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // TaskConfig defines the configuration of a new Task (see below). @@ -86,7 +86,7 @@ type TaskConfig struct { MountNamespaceVFS2 *vfs.MountNamespace // RSeqAddr is a pointer to the the userspace linux.RSeq structure. - RSeqAddr usermem.Addr + RSeqAddr hostarch.Addr // RSeqSignature is the signature that the rseq abort IP must be signed // with. @@ -151,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { rseqSignature: cfg.RSeqSignature, futexWaiter: futex.NewWaiter(), containerID: cfg.ContainerID, + cgroups: make(map[Cgroup]struct{}), } t.creds.Store(cfg.Credentials) t.endStopCond.L = &t.tg.signalHandlers.mu @@ -189,6 +190,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { t.parent.children[t] = struct{}{} } + if VFS2Enabled { + t.EnterInitialCgroups(t.parent) + } + if tg.leader == nil { // New thread group. tg.leader = t diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 2e84bd88a..2c658d001 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -22,12 +22,12 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application") @@ -153,7 +153,7 @@ func (t *Task) doSyscall() taskRunState { // Check seccomp filters. 
The nil check is for performance (as seccomp use // is rare), not needed for correctness. if t.syscallFilters.Load() != nil { - switch r := t.checkSeccompSyscall(int32(sysno), args, usermem.Addr(t.Arch().IP())); r { + switch r := t.checkSeccompSyscall(int32(sysno), args, hostarch.Addr(t.Arch().IP())); r { case linux.SECCOMP_RET_ERRNO, linux.SECCOMP_RET_TRAP: t.Debugf("Syscall %d: denied by seccomp", sysno) return (*runSyscallExit)(nil) @@ -283,12 +283,12 @@ func (*runSyscallExit) execute(t *Task) taskRunState { // doVsyscall is the entry point for a vsyscall invocation of syscall sysno, as // indicated by an execution fault at address addr. doVsyscall returns the // task's next run state. -func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState { +func (t *Task) doVsyscall(addr hostarch.Addr, sysno uintptr) taskRunState { vsyscallCount.Increment() // Grab the caller up front, to make sure there's a sensible stack. caller := t.Arch().Native(uintptr(0)) - if _, err := caller.CopyIn(t, usermem.Addr(t.Arch().Stack())); err != nil { + if _, err := caller.CopyIn(t, hostarch.Addr(t.Arch().Stack())); err != nil { t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err) t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) @@ -322,7 +322,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState { } type runVsyscallAfterPtraceEventSeccomp struct { - addr usermem.Addr + addr hostarch.Addr sysno uintptr caller marshal.Marshallable } @@ -337,7 +337,7 @@ func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState { // currently emulated call. ... The tracer MUST NOT modify rip or rsp." - // Documentation/prctl/seccomp_filter.txt. On Linux, changing orig_ax or ip // causes do_exit(SIGSYS), and changing sp is ignored. - if (sysno != ^uintptr(0) && sysno != r.sysno) || usermem.Addr(t.Arch().IP()) != r.addr { + if (sysno != ^uintptr(0) && sysno != r.sysno) || hostarch.Addr(t.Arch().IP()) != r.addr { t.PrepareExit(ExitStatus{Signo: int(linux.SIGSYS)}) return (*runExit)(nil) } diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index 94dabbcd8..fc6d9438a 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -27,7 +28,7 @@ import ( // MAX_RW_COUNT is the maximum size in bytes of a single read or write. // Reads and writes that exceed this size may be silently truncated. // (Linux: include/linux/fs.h:MAX_RW_COUNT) -var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown()) +var MAX_RW_COUNT = int(hostarch.Addr(math.MaxInt32).RoundDown()) // Activate ensures that the task has an active address space. func (t *Task) Activate() { @@ -49,7 +50,7 @@ func (t *Task) Deactivate() { // data without reflection and pass in a byte slice. // // This Task's AddressSpace must be active. -func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { +func (t *Task) CopyInBytes(addr hostarch.Addr, dst []byte) (int, error) { return t.MemoryManager().CopyIn(t, addr, dst, usermem.IOOpts{ AddressSpaceActive: true, }) @@ -59,7 +60,7 @@ func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { // data without reflection and pass in a byte slice. // // This Task's AddressSpace must be active. 
-func (t *Task) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { +func (t *Task) CopyOutBytes(addr hostarch.Addr, src []byte) (int, error) { return t.MemoryManager().CopyOut(t, addr, src, usermem.IOOpts{ AddressSpaceActive: true, }) @@ -70,7 +71,7 @@ func (t *Task) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { // user memory that is unmapped or not readable by the user. // // This Task's AddressSpace must be active. -func (t *Task) CopyInString(addr usermem.Addr, maxlen int) (string, error) { +func (t *Task) CopyInString(addr hostarch.Addr, maxlen int) (string, error) { return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxlen, usermem.IOOpts{ AddressSpaceActive: true, }) @@ -90,7 +91,7 @@ func (t *Task) CopyInString(addr usermem.Addr, maxlen int) (string, error) { // { "abc" } => 4 (3 for length, 1 for elements) // // This Task's AddressSpace must be active. -func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([]string, error) { +func (t *Task) CopyInVector(addr hostarch.Addr, maxElemSize, maxTotalSize int) ([]string, error) { var v []string for { argAddr := t.Arch().Native(0) @@ -109,12 +110,12 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([ if maxTotalSize < thisMax { thisMax = maxTotalSize } - arg, err := t.CopyInString(usermem.Addr(t.Arch().Value(argAddr)), thisMax) + arg, err := t.CopyInString(hostarch.Addr(t.Arch().Value(argAddr)), thisMax) if err != nil { return v, err } v = append(v, arg) - addr += usermem.Addr(t.Arch().Width()) + addr += hostarch.Addr(t.Arch().Width()) maxTotalSize -= len(arg) + 1 } return v, nil @@ -126,7 +127,7 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([ // Preconditions: Same as usermem.IO.CopyOut, plus: // * The caller must be running on the task goroutine. // * t's AddressSpace must be active. -func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error { +func (t *Task) CopyOutIovecs(addr hostarch.Addr, src hostarch.AddrRangeSeq) error { switch t.Arch().Width() { case 8: const itemLen = 16 @@ -137,8 +138,8 @@ func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error b := t.CopyScratchBuffer(itemLen) for ; !src.IsEmpty(); src = src.Tail() { ar := src.Head() - usermem.ByteOrder.PutUint64(b[0:8], uint64(ar.Start)) - usermem.ByteOrder.PutUint64(b[8:16], uint64(ar.Length())) + hostarch.ByteOrder.PutUint64(b[0:8], uint64(ar.Start)) + hostarch.ByteOrder.PutUint64(b[8:16], uint64(ar.Length())) if _, err := t.CopyOutBytes(addr, b); err != nil { return err } @@ -153,8 +154,8 @@ func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error } // CopyInIovecs copies an array of numIovecs struct iovecs from the memory -// mapped at addr, converts them to usermem.AddrRanges, and returns them as a -// usermem.AddrRangeSeq. +// mapped at addr, converts them to hostarch.AddrRanges, and returns them as a +// hostarch.AddrRangeSeq. // // CopyInIovecs shares the following properties with Linux's // lib/iov_iter.c:import_iovec() => fs/read_write.c:rw_copy_check_uvector(): @@ -175,42 +176,42 @@ func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error // Preconditions: Same as usermem.IO.CopyIn, plus: // * The caller must be running on the task goroutine. // * t's AddressSpace must be active. 
-func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRangeSeq, error) { +func (t *Task) CopyInIovecs(addr hostarch.Addr, numIovecs int) (hostarch.AddrRangeSeq, error) { if numIovecs == 0 { - return usermem.AddrRangeSeq{}, nil + return hostarch.AddrRangeSeq{}, nil } - var dst []usermem.AddrRange + var dst []hostarch.AddrRange if numIovecs > 1 { - dst = make([]usermem.AddrRange, 0, numIovecs) + dst = make([]hostarch.AddrRange, 0, numIovecs) } switch t.Arch().Width() { case 8: const itemLen = 16 if _, ok := addr.AddLength(uint64(numIovecs) * itemLen); !ok { - return usermem.AddrRangeSeq{}, syserror.EFAULT + return hostarch.AddrRangeSeq{}, syserror.EFAULT } b := t.CopyScratchBuffer(itemLen) for i := 0; i < numIovecs; i++ { if _, err := t.CopyInBytes(addr, b); err != nil { - return usermem.AddrRangeSeq{}, err + return hostarch.AddrRangeSeq{}, err } - base := usermem.Addr(usermem.ByteOrder.Uint64(b[0:8])) - length := usermem.ByteOrder.Uint64(b[8:16]) + base := hostarch.Addr(hostarch.ByteOrder.Uint64(b[0:8])) + length := hostarch.ByteOrder.Uint64(b[8:16]) if length > math.MaxInt64 { - return usermem.AddrRangeSeq{}, syserror.EINVAL + return hostarch.AddrRangeSeq{}, syserror.EINVAL } ar, ok := t.MemoryManager().CheckIORange(base, int64(length)) if !ok { - return usermem.AddrRangeSeq{}, syserror.EFAULT + return hostarch.AddrRangeSeq{}, syserror.EFAULT } if numIovecs == 1 { // Special case to avoid allocating dst. - return usermem.AddrRangeSeqOf(ar).TakeFirst(MAX_RW_COUNT), nil + return hostarch.AddrRangeSeqOf(ar).TakeFirst(MAX_RW_COUNT), nil } dst = append(dst, ar) @@ -218,7 +219,7 @@ func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRange } default: - return usermem.AddrRangeSeq{}, syserror.ENOSYS + return hostarch.AddrRangeSeq{}, syserror.ENOSYS } // Truncate to MAX_RW_COUNT. @@ -226,13 +227,13 @@ func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRange for i := range dst { dstlen := uint64(dst[i].Length()) if rem := uint64(MAX_RW_COUNT) - total; rem < dstlen { - dst[i].End -= usermem.Addr(dstlen - rem) + dst[i].End -= hostarch.Addr(dstlen - rem) dstlen = rem } total += dstlen } - return usermem.AddrRangeSeqFromSlice(dst), nil + return hostarch.AddrRangeSeqFromSlice(dst), nil } // SingleIOSequence returns a usermem.IOSequence representing [addr, @@ -245,7 +246,7 @@ func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRange // write syscalls in Linux do not use import_single_range(). However they check // access_ok() in fs/read_write.c:vfs_read/vfs_write, and overflowing address // ranges are truncated to MAX_RW_COUNT by fs/read_write.c:rw_verify_area().) -func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOpts) (usermem.IOSequence, error) { +func (t *Task) SingleIOSequence(addr hostarch.Addr, length int, opts usermem.IOOpts) (usermem.IOSequence, error) { if length > MAX_RW_COUNT { length = MAX_RW_COUNT } @@ -255,7 +256,7 @@ func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOp } return usermem.IOSequence{ IO: t.MemoryManager(), - Addrs: usermem.AddrRangeSeqOf(ar), + Addrs: hostarch.AddrRangeSeqOf(ar), Opts: opts, }, nil } @@ -267,7 +268,7 @@ func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOp // IovecsIOSequence is analogous to Linux's lib/iov_iter.c:import_iovec(). // // Preconditions: Same as Task.CopyInIovecs. 
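Aside: CopyInIovecs above reads each 64-bit struct iovec as a 16-byte record (8-byte iov_base, 8-byte iov_len) in native byte order, and rejects any iov_len larger than MaxInt64 with EINVAL before range-checking it. A minimal illustrative sketch (not gVisor code) of decoding one such record, assuming a little-endian host:

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
	"math"
)

// decodeIovec64 decodes iov_base and iov_len from a 16-byte buffer, applying
// the same oversized-length check as the hunk above.
func decodeIovec64(b []byte) (base, length uint64, err error) {
	if len(b) < 16 {
		return 0, 0, errors.New("short iovec buffer")
	}
	base = binary.LittleEndian.Uint64(b[0:8])
	length = binary.LittleEndian.Uint64(b[8:16])
	if length > math.MaxInt64 {
		return 0, 0, errors.New("EINVAL: iov_len exceeds MaxInt64")
	}
	return base, length, nil
}

func main() {
	var b [16]byte
	binary.LittleEndian.PutUint64(b[0:8], 0x7f0000001000)
	binary.LittleEndian.PutUint64(b[8:16], 4096)
	fmt.Println(decodeIovec64(b[:]))
}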
-func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) { +func (t *Task) IovecsIOSequence(addr hostarch.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) { if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV { return usermem.IOSequence{}, syserror.EINVAL } @@ -317,7 +318,7 @@ func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) { } // CopyInBytes implements marshal.CopyContext.CopyInBytes. -func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { +func (cc *taskCopyContext) CopyInBytes(addr hostarch.Addr, dst []byte) (int, error) { tmm, err := cc.getMemoryManager() if err != nil { return 0, err @@ -327,7 +328,7 @@ func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, erro } // CopyOutBytes implements marshal.CopyContext.CopyOutBytes. -func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { +func (cc *taskCopyContext) CopyOutBytes(addr hostarch.Addr, src []byte) (int, error) { tmm, err := cc.getMemoryManager() if err != nil { return 0, err @@ -360,11 +361,11 @@ func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte { } // CopyInBytes implements marshal.CopyContext.CopyInBytes. -func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { +func (cc *ownTaskCopyContext) CopyInBytes(addr hostarch.Addr, dst []byte) (int, error) { return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts) } // CopyOutBytes implements marshal.CopyContext.CopyOutBytes. -func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { +func (cc *ownTaskCopyContext) CopyOutBytes(addr hostarch.Addr, src []byte) (int, error) { return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts) } diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 09d070ec8..77ad62445 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -114,6 +114,15 @@ func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) { } } +// forEachTaskLocked applies f to each Task in ts. +// +// Preconditions: ts.mu must be locked (for reading or writing). +func (ts *TaskSet) forEachTaskLocked(f func(t *Task)) { + for t := range ts.Root.tids { + f(t) + } +} + // A PIDNamespace represents a PID namespace, a bimap between thread IDs and // tasks. See the pid_namespaces(7) man page for further details. 
// diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go index cf2f7ca72..dfc3c0719 100644 --- a/pkg/sentry/kernel/timekeeper_test.go +++ b/pkg/sentry/kernel/timekeeper_test.go @@ -17,12 +17,12 @@ package kernel import ( "testing" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/pgalloc" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // mockClocks is a sentrytime.Clocks that simply returns the times in the @@ -54,7 +54,7 @@ func (c *mockClocks) GetTime(id sentrytime.ClockID) (int64, error) { func stateTestClocklessTimekeeper(tb testing.TB) *Timekeeper { ctx := contexttest.Context(tb) mfp := pgalloc.MemoryFileProviderFromContext(ctx) - fr, err := mfp.MemoryFile().Allocate(usermem.PageSize, usage.Anonymous) + fr, err := mfp.MemoryFile().Allocate(hostarch.PageSize, usage.Anonymous) if err != nil { tb.Fatalf("failed to allocate memory: %v", err) } diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index 9e5c2d26f..cc0917504 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -17,10 +17,10 @@ package kernel import ( "fmt" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/usermem" ) // vdsoParams are the parameters exposed to the VDSO. @@ -96,7 +96,7 @@ func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSO // access returns a mapping of the param page. func (v *VDSOParamPage) access() (safemem.Block, error) { - bs, err := v.mfp.MemoryFile().MapInternal(v.fr, usermem.ReadWrite) + bs, err := v.mfp.MemoryFile().MapInternal(v.fr, hostarch.ReadWrite) if err != nil { return safemem.Block{}, err } diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index ab074b400..ecb6603a1 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -18,6 +18,7 @@ go_library( "//pkg/binary", "//pkg/context", "//pkg/cpuid", + "//pkg/hostarch", "//pkg/log", "//pkg/rand", "//pkg/safemem", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index cd9fa4031..e92d9fdc3 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -41,7 +42,7 @@ const ( // maxTotalPhdrSize is the maximum combined size of all program // headers. Linux limits this to one page. - maxTotalPhdrSize = usermem.PageSize + maxTotalPhdrSize = hostarch.PageSize ) var ( @@ -52,8 +53,8 @@ var ( prog64Size = int(binary.Size(elf.Prog64{})) ) -func progFlagsAsPerms(f elf.ProgFlag) usermem.AccessType { - var p usermem.AccessType +func progFlagsAsPerms(f elf.ProgFlag) hostarch.AccessType { + var p hostarch.AccessType if f&elf.PF_R == elf.PF_R { p.Read = true } @@ -75,7 +76,7 @@ type elfInfo struct { arch arch.Arch // entry is the program entry point. - entry usermem.Addr + entry hostarch.Addr // phdrs are the program headers. 
phdrs []elf.ProgHeader @@ -230,7 +231,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { return elfInfo{ os: os, arch: a, - entry: usermem.Addr(hdr.Entry), + entry: hostarch.Addr(hdr.Entry), phdrs: phdrs, phdrOff: hdr.Phoff, phdrSize: prog64Size, @@ -240,9 +241,9 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // mapSegment maps a phdr into the Task. offset is the offset to apply to // phdr.Vaddr. -func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr *elf.ProgHeader, offset usermem.Addr) error { +func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr *elf.ProgHeader, offset hostarch.Addr) error { // We must make a page-aligned mapping. - adjust := usermem.Addr(phdr.Vaddr).PageOffset() + adjust := hostarch.Addr(phdr.Vaddr).PageOffset() addr, ok := offset.AddLength(phdr.Vaddr) if !ok { @@ -250,14 +251,14 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr ctx.Warningf("Computed segment load address overflows: %#x + %#x", phdr.Vaddr, offset) return syserror.ENOEXEC } - addr -= usermem.Addr(adjust) + addr -= hostarch.Addr(adjust) fileSize := phdr.Filesz + adjust if fileSize < phdr.Filesz { ctx.Infof("Computed segment file size overflows: %#x + %#x", phdr.Filesz, adjust) return syserror.ENOEXEC } - ms, ok := usermem.Addr(fileSize).RoundUp() + ms, ok := hostarch.Addr(fileSize).RoundUp() if !ok { ctx.Infof("fileSize %#x too large", fileSize) return syserror.ENOEXEC @@ -281,7 +282,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr Unmap: true, Private: true, Perms: prot, - MaxPerms: usermem.AnyAccess, + MaxPerms: hostarch.AnyAccess, } defer func() { if mopts.MappingIdentity != nil { @@ -312,7 +313,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr panic(fmt.Sprintf("zeroSize too big? %#x", uint64(zeroSize))) } if _, err := m.ZeroOut(ctx, zeroAddr, zeroSize, usermem.IOOpts{IgnorePermissions: true}); err != nil { - ctx.Warningf("Failed to zero end of page [%#x, %#x): %v", zeroAddr, zeroAddr+usermem.Addr(zeroSize), err) + ctx.Warningf("Failed to zero end of page [%#x, %#x): %v", zeroAddr, zeroAddr+hostarch.Addr(zeroSize), err) return err } } @@ -330,7 +331,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr if !ok { panic(fmt.Sprintf("anonymous memory doesn't fit in pre-sized range? %#x + %#x", addr, mapSize)) } - anonSize, ok := usermem.Addr(memSize - mapSize).RoundUp() + anonSize, ok := hostarch.Addr(memSize - mapSize).RoundUp() if !ok { ctx.Infof("extra anon pages too large: %#x", memSize-mapSize) return syserror.ENOEXEC @@ -339,7 +340,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr // N.B. Linux uses vm_brk_flags to map these pages, which only // honors the X bit, always mapping at least RW. ignoring These // pages are not included in the final brk region. - prot := usermem.ReadWrite + prot := hostarch.ReadWrite if phdr.Flags&elf.PF_X == elf.PF_X { prot.Execute = true } @@ -352,7 +353,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr Fixed: true, Private: true, Perms: prot, - MaxPerms: usermem.AnyAccess, + MaxPerms: hostarch.AnyAccess, }); err != nil { ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err) return err @@ -371,19 +372,19 @@ type loadedELF struct { arch arch.Arch // entry is the entry point of the ELF. 
- entry usermem.Addr + entry hostarch.Addr // start is the end of the ELF. - start usermem.Addr + start hostarch.Addr // end is the end of the ELF. - end usermem.Addr + end hostarch.Addr // interpter is the path to the ELF interpreter. interpreter string // phdrAddr is the address of the ELF program headers. - phdrAddr usermem.Addr + phdrAddr hostarch.Addr // phdrSize is the size of a single program header in the ELF. phdrSize int @@ -407,14 +408,14 @@ type loadedELF struct { // It does not load the ELF interpreter, or return any auxv entries. // // Preconditions: f is an ELF file. -func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) { +func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset hostarch.Addr) (loadedELF, error) { first := true - var start, end usermem.Addr + var start, end hostarch.Addr var interpreter string for _, phdr := range info.phdrs { switch phdr.Type { case elf.PT_LOAD: - vaddr := usermem.Addr(phdr.Vaddr) + vaddr := hostarch.Addr(phdr.Vaddr) if first { first = false start = vaddr @@ -492,7 +493,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in // Note that the vaddr of the first PT_LOAD segment is ignored when // choosing the load address (even if it is non-zero). The vaddr does // become an offset from that load address. - var offset usermem.Addr + var offset hostarch.Addr if info.sharedObject { totalSize := end - start totalSize, ok := totalSize.RoundUp() @@ -688,8 +689,8 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error // ELF-specific auxv entries. bin.auxv = arch.Auxv{ arch.AuxEntry{linux.AT_PHDR, bin.phdrAddr}, - arch.AuxEntry{linux.AT_PHENT, usermem.Addr(bin.phdrSize)}, - arch.AuxEntry{linux.AT_PHNUM, usermem.Addr(bin.phdrNum)}, + arch.AuxEntry{linux.AT_PHENT, hostarch.Addr(bin.phdrSize)}, + arch.AuxEntry{linux.AT_PHNUM, hostarch.Addr(bin.phdrNum)}, arch.AuxEntry{linux.AT_ENTRY, bin.entry}, } if bin.interpreter != "" { diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index c69b62db9..47e3775a3 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -266,17 +267,17 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V // Add generic auxv entries. auxv := append(loaded.auxv, arch.Auxv{ - arch.AuxEntry{linux.AT_UID, usermem.Addr(c.RealKUID.In(c.UserNamespace).OrOverflow())}, - arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())}, - arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())}, - arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())}, + arch.AuxEntry{linux.AT_UID, hostarch.Addr(c.RealKUID.In(c.UserNamespace).OrOverflow())}, + arch.AuxEntry{linux.AT_EUID, hostarch.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())}, + arch.AuxEntry{linux.AT_GID, hostarch.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())}, + arch.AuxEntry{linux.AT_EGID, hostarch.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())}, // The conditions that require AT_SECURE = 1 never arise. See // kernel.Task.updateCredsForExecLocked. 
arch.AuxEntry{linux.AT_SECURE, 0}, arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC}, arch.AuxEntry{linux.AT_EXECFN, execfn}, arch.AuxEntry{linux.AT_RANDOM, random}, - arch.AuxEntry{linux.AT_PAGESZ, usermem.PageSize}, + arch.AuxEntry{linux.AT_PAGESZ, hostarch.PageSize}, arch.AuxEntry{linux.AT_SYSINFO_EHDR, vdsoAddr}, }...) auxv = append(auxv, extraAuxv...) diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index a32d37d62..fd54261fd 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -90,7 +91,7 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro var first *elf.ProgHeader var prev *elf.ProgHeader - var prevEnd usermem.Addr + var prevEnd hostarch.Addr for i, phdr := range info.phdrs { if phdr.Type != elf.PT_LOAD { continue @@ -119,7 +120,7 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro return elfInfo{}, syserror.ENOEXEC } - start := usermem.Addr(memoryOffset) + start := hostarch.Addr(memoryOffset) end, ok := start.AddLength(phdr.Memsz) if !ok { log.Warningf("PT_LOAD segment size overflows: %#x + %#x", start, end) @@ -210,7 +211,7 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { } // Then copy it into a VDSO mapping. - size, ok := usermem.Addr(len(vdsodata.Binary)).RoundUp() + size, ok := hostarch.Addr(len(vdsodata.Binary)).RoundUp() if !ok { return nil, fmt.Errorf("VDSO size overflows? %#x", len(vdsodata.Binary)) } @@ -221,7 +222,7 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { return nil, fmt.Errorf("unable to allocate VDSO memory: %v", err) } - ims, err := mf.MapInternal(vdso, usermem.ReadWrite) + ims, err := mf.MapInternal(vdso, hostarch.ReadWrite) if err != nil { mf.DecRef(vdso) return nil, fmt.Errorf("unable to map VDSO memory: %v", err) @@ -234,7 +235,7 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { } // Finally, allocate a param page for this VDSO. - paramPage, err := mf.Allocate(usermem.PageSize, usage.System) + paramPage, err := mf.Allocate(hostarch.PageSize, usage.System) if err != nil { mf.DecRef(vdso) return nil, fmt.Errorf("unable to allocate VDSO param page: %v", err) @@ -266,7 +267,7 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { // compatibility with such binaries, we load the VDSO much like Linux. // // loadVDSO takes a reference on the VDSO and parameter page FrameRegions. 
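The loader and VDSO hunks above lean on a handful of hostarch.Addr helpers that replaced their usermem counterparts: PageOffset to find a segment address's misalignment, AddLength for overflow-checked sizing, and RoundUp to extend a length to a page boundary. A minimal sketch of that arithmetic, using only methods that appear in this diff; the concrete addresses and lengths are made up for illustration:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

func main() {
	// A segment virtual address that is not page-aligned, as mapSegment
	// might see in phdr.Vaddr.
	vaddr := hostarch.Addr(0x401234)

	adjust := vaddr.PageOffset()          // 0x234: offset within its page.
	base := vaddr - hostarch.Addr(adjust) // 0x401000: page-aligned start.

	// Overflow-checked addition of the segment length.
	end, ok := base.AddLength(0x2345)
	if !ok {
		panic("segment end overflows the address space")
	}

	// Round the end up to the next page boundary, again checking overflow.
	rounded, ok := end.RoundUp()
	if !ok {
		panic("rounded end overflows the address space")
	}

	fmt.Printf("base=%#x end=%#x rounded=%#x\n", base, end, rounded)
}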
-func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) (usermem.Addr, error) { +func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) (hostarch.Addr, error) { if v.os != bin.os { ctx.Warningf("Binary ELF OS %v and VDSO ELF OS %v differ", bin.os, v.os) return 0, syserror.ENOEXEC @@ -297,8 +298,8 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) Fixed: true, Unmap: true, Private: true, - Perms: usermem.Read, - MaxPerms: usermem.Read, + Perms: hostarch.Read, + MaxPerms: hostarch.Read, }) if err != nil { ctx.Infof("Unable to map VDSO param page: %v", err) @@ -318,8 +319,8 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) Fixed: true, Unmap: true, Private: true, - Perms: usermem.Read, - MaxPerms: usermem.AnyAccess, + Perms: hostarch.Read, + MaxPerms: hostarch.AnyAccess, }) if err != nil { ctx.Infof("Unable to map VDSO: %v", err) @@ -349,7 +350,7 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) return 0, syserror.ENOEXEC } segPage := segAddr.RoundDown() - segSize := usermem.Addr(phdr.Memsz) + segSize := hostarch.Addr(phdr.Memsz) segSize, ok = segSize.AddLength(segAddr.PageOffset()) if !ok { ctx.Warningf("PT_LOAD segment memsize %#x + offset %#x overflows", phdr.Memsz, segAddr.PageOffset()) @@ -371,7 +372,7 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) } perms := progFlagsAsPerms(phdr.Flags) - if perms != usermem.Read { + if perms != hostarch.Read { if err := m.MProtect(segPage, uint64(segSize), perms, false); err != nil { ctx.Warningf("Unable to set PT_LOAD segment protections %+v at [%#x, %#x): %v", perms, segAddr, segEnd, err) return 0, syserror.ENOEXEC diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 2c95669cd..c30e88725 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -51,6 +51,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/safemem", "//pkg/syserror", @@ -63,5 +64,5 @@ go_test( size = "small", srcs = ["mapping_set_test.go"], library = ":memmap", - deps = ["//pkg/usermem"], + deps = ["//pkg/hostarch"], ) diff --git a/pkg/sentry/memmap/mapping_set.go b/pkg/sentry/memmap/mapping_set.go index 457ed87f8..32863bb5e 100644 --- a/pkg/sentry/memmap/mapping_set.go +++ b/pkg/sentry/memmap/mapping_set.go @@ -18,7 +18,7 @@ import ( "fmt" "math" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // MappingSet maps offsets into a Mappable to mappings of those offsets. It is @@ -39,7 +39,7 @@ type MappingsOfRange map[MappingOfRange]struct{} // +stateify savable type MappingOfRange struct { MappingSpace MappingSpace - AddrRange usermem.AddrRange + AddrRange hostarch.AddrRange Writable bool } @@ -89,9 +89,9 @@ func (mappingSetFunctions) Merge(r1 MappableRange, val1 MappingsOfRange, r2 Mapp // region with k1. k2 := MappingOfRange{ MappingSpace: k1.MappingSpace, - AddrRange: usermem.AddrRange{ + AddrRange: hostarch.AddrRange{ Start: k1.AddrRange.End, - End: k1.AddrRange.End + usermem.Addr(r2.Length()), + End: k1.AddrRange.End + hostarch.Addr(r2.Length()), }, Writable: k1.Writable, } @@ -102,7 +102,7 @@ func (mappingSetFunctions) Merge(r1 MappableRange, val1 MappingsOfRange, r2 Mapp // OK. Add it to the merged map. 
merged[MappingOfRange{ MappingSpace: k1.MappingSpace, - AddrRange: usermem.AddrRange{ + AddrRange: hostarch.AddrRange{ Start: k1.AddrRange.Start, End: k2.AddrRange.End, }, @@ -124,11 +124,11 @@ func (mappingSetFunctions) Split(r MappableRange, val MappingsOfRange, split uin // split is a value in MappableRange, we need the offset into the // corresponding MappingsOfRange. - offset := usermem.Addr(split - r.Start) + offset := hostarch.Addr(split - r.Start) for k := range val { k1 := MappingOfRange{ MappingSpace: k.MappingSpace, - AddrRange: usermem.AddrRange{ + AddrRange: hostarch.AddrRange{ Start: k.AddrRange.Start, End: k.AddrRange.Start + offset, }, @@ -138,7 +138,7 @@ func (mappingSetFunctions) Split(r MappableRange, val MappingsOfRange, split uin k2 := MappingOfRange{ MappingSpace: k.MappingSpace, - AddrRange: usermem.AddrRange{ + AddrRange: hostarch.AddrRange{ Start: k.AddrRange.Start + offset, End: k.AddrRange.End, }, @@ -157,18 +157,18 @@ func (mappingSetFunctions) Split(r MappableRange, val MappingsOfRange, split uin // indicating that ms maps addresses [0x4000, 0x6000) to MappableRange [0x0, // 0x2000). Then for subsetRange = [0x1000, 0x2000), subsetMapping returns a // MappingOfRange for which AddrRange = [0x5000, 0x6000). -func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr usermem.Addr, writable bool) MappingOfRange { +func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr hostarch.Addr, writable bool) MappingOfRange { if !wholeRange.IsSupersetOf(subsetRange) { panic(fmt.Sprintf("%v is not a superset of %v", wholeRange, subsetRange)) } offset := subsetRange.Start - wholeRange.Start - start := addr + usermem.Addr(offset) + start := addr + hostarch.Addr(offset) return MappingOfRange{ MappingSpace: ms, - AddrRange: usermem.AddrRange{ + AddrRange: hostarch.AddrRange{ Start: start, - End: start + usermem.Addr(subsetRange.Length()), + End: start + hostarch.Addr(subsetRange.Length()), }, Writable: writable, } @@ -178,7 +178,7 @@ func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr // previously had no mappings. // // Preconditions: Same as Mappable.AddMapping. -func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange { +func (s *MappingSet) AddMapping(ms MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) []MappableRange { mr := MappableRange{offset, offset + uint64(ar.Length())} var mapped []MappableRange seg, gap := s.Find(mr.Start) @@ -205,7 +205,7 @@ func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset ui // MappableRanges that now have no mappings. // // Preconditions: Same as Mappable.RemoveMapping. 
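The subsetMapping comment in the hunk above works through a concrete case: wholeRange [0x0, 0x2000) mapped at address 0x4000, so subsetRange [0x1000, 0x2000) lands at AddrRange [0x5000, 0x6000). The same numbers fall out of plain hostarch/memmap arithmetic; this is a sketch only, since subsetMapping itself is unexported:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/memmap"
)

func main() {
	wholeRange := memmap.MappableRange{Start: 0x0, End: 0x2000}
	subsetRange := memmap.MappableRange{Start: 0x1000, End: 0x2000}
	addr := hostarch.Addr(0x4000) // where wholeRange is mapped

	offset := subsetRange.Start - wholeRange.Start     // 0x1000
	start := addr + hostarch.Addr(offset)              // 0x5000
	end := start + hostarch.Addr(subsetRange.Length()) // 0x6000

	fmt.Printf("AddrRange = [%#x, %#x)\n", start, end)
}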
-func (s *MappingSet) RemoveMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange { +func (s *MappingSet) RemoveMapping(ms MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) []MappableRange { mr := MappableRange{offset, offset + uint64(ar.Length())} var unmapped []MappableRange diff --git a/pkg/sentry/memmap/mapping_set_test.go b/pkg/sentry/memmap/mapping_set_test.go index d39efe38f..5cb81fde7 100644 --- a/pkg/sentry/memmap/mapping_set_test.go +++ b/pkg/sentry/memmap/mapping_set_test.go @@ -15,24 +15,23 @@ package memmap import ( + "gvisor.dev/gvisor/pkg/hostarch" "reflect" "testing" - - "gvisor.dev/gvisor/pkg/usermem" ) type testMappingSpace struct { // Ideally we'd store the full ranges that were invalidated, rather // than individual calls to Invalidate, as they are an implementation // detail, but this is the simplest way for now. - inv []usermem.AddrRange + inv []hostarch.AddrRange } func (n *testMappingSpace) reset() { - n.inv = []usermem.AddrRange{} + n.inv = []hostarch.AddrRange{} } -func (n *testMappingSpace) Invalidate(ar usermem.AddrRange, opts InvalidateOpts) { +func (n *testMappingSpace) Invalidate(ar hostarch.AddrRange, opts InvalidateOpts) { n.inv = append(n.inv, ar) } @@ -40,16 +39,16 @@ func TestAddRemoveMapping(t *testing.T) { set := MappingSet{} ms := &testMappingSpace{} - mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) + mapped := set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true) if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) { t.Errorf("AddMapping: got %+v, wanted %+v", got, want) } - // Mappings (usermem.AddrRanges => memmap.MappableRange): + // Mappings (hostarch.AddrRanges => memmap.MappableRange): // [0x10000, 0x12000) => [0x1000, 0x3000) t.Log(&set) - mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) + mapped = set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true) if len(mapped) != 0 { t.Errorf("AddMapping: got %+v, wanted []", mapped) } @@ -59,7 +58,7 @@ func TestAddRemoveMapping(t *testing.T) { // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) t.Log(&set) - mapped = set.AddMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true) + mapped = set.AddMapping(ms, hostarch.AddrRange{0x30000, 0x31000}, 0x4000, true) if got, want := mapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) { t.Errorf("AddMapping: got %+v, wanted %+v", got, want) } @@ -70,7 +69,7 @@ func TestAddRemoveMapping(t *testing.T) { // [0x30000, 0x31000) => [0x4000, 0x5000) t.Log(&set) - mapped = set.AddMapping(ms, usermem.AddrRange{0x12000, 0x15000}, 0x3000, true) + mapped = set.AddMapping(ms, hostarch.AddrRange{0x12000, 0x15000}, 0x3000, true) if got, want := mapped, []MappableRange{{0x3000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) { t.Errorf("AddMapping: got %+v, wanted %+v", got, want) } @@ -83,7 +82,7 @@ func TestAddRemoveMapping(t *testing.T) { // [0x14000, 0x15000) => [0x5000, 0x6000) t.Log(&set) - unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0x1000, true) + unmapped := set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0x1000, true) if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) { t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) } @@ -95,7 +94,7 @@ func TestAddRemoveMapping(t *testing.T) { // [0x14000, 0x15000) => [0x5000, 0x6000) t.Log(&set) - unmapped = 
set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) + unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true) if len(unmapped) != 0 { t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) } @@ -106,7 +105,7 @@ func TestAddRemoveMapping(t *testing.T) { // [0x14000, 0x15000) => [0x5000, 0x6000) t.Log(&set) - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x15000}, 0x2000, true) + unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x11000, 0x15000}, 0x2000, true) if got, want := unmapped, []MappableRange{{0x2000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) { t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) } @@ -115,7 +114,7 @@ func TestAddRemoveMapping(t *testing.T) { // [0x30000, 0x31000) => [0x4000, 0x5000) t.Log(&set) - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true) + unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x30000, 0x31000}, 0x4000, true) if got, want := unmapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) { t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) } @@ -125,12 +124,12 @@ func TestInvalidateWholeMapping(t *testing.T) { set := MappingSet{} ms := &testMappingSpace{} - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true) + set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0, true) // Mappings: // [0x10000, 0x11000) => [0, 0x1000) t.Log(&set) set.Invalidate(MappableRange{0, 0x1000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}}; !reflect.DeepEqual(got, want) { + if got, want := ms.inv, []hostarch.AddrRange{{0x10000, 0x11000}}; !reflect.DeepEqual(got, want) { t.Errorf("Invalidate: got %+v, wanted %+v", got, want) } } @@ -139,12 +138,12 @@ func TestInvalidatePartialMapping(t *testing.T) { set := MappingSet{} ms := &testMappingSpace{} - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x13000}, 0, true) + set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x13000}, 0, true) // Mappings: // [0x10000, 0x13000) => [0, 0x3000) t.Log(&set) set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { + if got, want := ms.inv, []hostarch.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { t.Errorf("Invalidate: got %+v, wanted %+v", got, want) } } @@ -153,14 +152,14 @@ func TestInvalidateMultipleMappings(t *testing.T) { set := MappingSet{} ms := &testMappingSpace{} - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true) - set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) + set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0, true) + set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true) // Mappings: // [0x10000, 0x11000) => [0, 0x1000) // [0x12000, 0x13000) => [0x2000, 0x3000) t.Log(&set) set.Invalidate(MappableRange{0, 0x3000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}, {0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { + if got, want := ms.inv, []hostarch.AddrRange{{0x10000, 0x11000}, {0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { t.Errorf("Invalidate: got %+v, wanted %+v", got, want) } } @@ -170,17 +169,17 @@ func TestInvalidateOverlappingMappings(t *testing.T) { ms1 := &testMappingSpace{} ms2 := &testMappingSpace{} - set.AddMapping(ms1, usermem.AddrRange{0x10000, 0x12000}, 0, true) - set.AddMapping(ms2, usermem.AddrRange{0x20000, 0x22000}, 0x1000, true) + 
set.AddMapping(ms1, hostarch.AddrRange{0x10000, 0x12000}, 0, true) + set.AddMapping(ms2, hostarch.AddrRange{0x20000, 0x22000}, 0x1000, true) // Mappings: // ms1:[0x10000, 0x12000) => [0, 0x2000) // ms2:[0x11000, 0x13000) => [0x1000, 0x3000) t.Log(&set) set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{}) - if got, want := ms1.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { + if got, want := ms1.inv, []hostarch.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want) } - if got, want := ms2.inv, []usermem.AddrRange{{0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { + if got, want := ms2.inv, []hostarch.AddrRange{{0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want) } } @@ -189,7 +188,7 @@ func TestMixedWritableMappings(t *testing.T) { set := MappingSet{} ms := &testMappingSpace{} - mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) + mapped := set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true) if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) { t.Errorf("AddMapping: got %+v, wanted %+v", got, want) } @@ -198,7 +197,7 @@ func TestMixedWritableMappings(t *testing.T) { // [0x10000, 0x12000) writable => [0x1000, 0x3000) t.Log(&set) - mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x22000}, 0x2000, false) + mapped = set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x22000}, 0x2000, false) if got, want := mapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) { t.Errorf("AddMapping: got %+v, wanted %+v", got, want) } @@ -211,14 +210,14 @@ func TestMixedWritableMappings(t *testing.T) { // Unmap should fail because we specified the readonly map address range, but // asked to unmap a writable segment. - unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) + unmapped := set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true) if len(unmapped) != 0 { t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) } // Readonly mapping removed, but writable mapping still exists in the range, // so no mappable range fully unmapped. - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, false) + unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, false) if len(unmapped) != 0 { t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) } @@ -228,7 +227,7 @@ func TestMixedWritableMappings(t *testing.T) { // [0x21000, 0x22000) readonly => [0x3000, 0x4000) t.Log(&set) - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x12000}, 0x2000, true) + unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x11000, 0x12000}, 0x2000, true) if got, want := unmapped, []MappableRange{{0x2000, 0x3000}}; !reflect.DeepEqual(got, want) { t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) } @@ -239,12 +238,12 @@ func TestMixedWritableMappings(t *testing.T) { t.Log(&set) // Unmap should fail since writable bit doesn't match. 
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, false) + unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, false) if len(unmapped) != 0 { t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) } - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) + unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true) if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) { t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) } @@ -253,7 +252,7 @@ func TestMixedWritableMappings(t *testing.T) { // [0x21000, 0x22000) readonly => [0x3000, 0x4000) t.Log(&set) - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x21000, 0x22000}, 0x3000, false) + unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x21000, 0x22000}, 0x3000, false) if got, want := unmapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) { t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) } diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 49e21026e..610686ea0 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -19,8 +19,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/usermem" ) // Mappable represents a memory-mappable object, a mutable mapping from uint64 @@ -29,8 +29,8 @@ import ( // See mm/mm.go for Mappable's place in the lock order. // // All Mappable methods have the following preconditions: -// * usermem.AddrRanges and MappableRanges must be non-empty (Length() != 0). -// * usermem.Addrs and Mappable offsets must be page-aligned. +// * hostarch.AddrRanges and MappableRanges must be non-empty (Length() != 0). +// * hostarch.Addrs and Mappable offsets must be page-aligned. type Mappable interface { // AddMapping notifies the Mappable of a mapping from addresses ar in ms to // offsets [offset, offset+ar.Length()) in this Mappable. @@ -42,7 +42,7 @@ type Mappable interface { // lifetime of the mapping. // // Preconditions: offset+ar.Length() does not overflow. - AddMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error + AddMapping(ctx context.Context, ms MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error // RemoveMapping notifies the Mappable of the removal of a mapping from // addresses ar in ms to offsets [offset, offset+ar.Length()) in this @@ -52,7 +52,7 @@ type Mappable interface { // * offset+ar.Length() does not overflow. // * The removed mapping must exist. writable must match the // corresponding call to AddMapping. - RemoveMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) + RemoveMapping(ctx context.Context, ms MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) // CopyMapping notifies the Mappable of an attempt to copy a mapping in ms // from srcAR to dstAR. For most Mappables, this is equivalent to @@ -66,7 +66,7 @@ type Mappable interface { // * offset+srcAR.Length() and offset+dstAR.Length() do not overflow. // * The mapping at srcAR must exist. writable must match the // corresponding call to AddMapping. 
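The Mappable preconditions above (non-empty ranges, page-aligned hostarch.Addrs and offsets) are the same checks CheckTranslateResult applies to Translate's results further down. A small sketch of verifying them by hand with the helpers this diff switches to; the sample range is arbitrary:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

func main() {
	ar := hostarch.AddrRange{Start: 0x10000, End: 0x12000}

	ok := ar.WellFormed() && // Start <= End
		ar.Length() != 0 && // non-empty
		ar.Start.IsPageAligned() &&
		ar.End.IsPageAligned()

	fmt.Println("satisfies the Mappable preconditions:", ok)
}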
- CopyMapping(ctx context.Context, ms MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error + CopyMapping(ctx context.Context, ms MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error // Translate returns the Mappable's current mappings for at least the range // of offsets specified by required, and at most the range of offsets @@ -90,7 +90,7 @@ type Mappable interface { // synchronize with invalidation. // // Postconditions: See CheckTranslateResult. - Translate(ctx context.Context, required, optional MappableRange, at usermem.AccessType) ([]Translation, error) + Translate(ctx context.Context, required, optional MappableRange, at hostarch.AccessType) ([]Translation, error) // InvalidateUnsavable requests that the Mappable invalidate Translations // that cannot be preserved across save/restore. @@ -113,7 +113,7 @@ type Translation struct { // Perms is the set of permissions for which platform.AddressSpace.MapFile // and platform.AddressSpace.MapInternal on this Translation is permitted. - Perms usermem.AccessType + Perms hostarch.AccessType } // FileRange returns the FileRange represented by t. @@ -125,18 +125,18 @@ func (t Translation) FileRange() FileRange { // postconditions for Mappable.Translate(required, optional, at). // // Preconditions: Same as Mappable.Translate. -func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error { +func CheckTranslateResult(required, optional MappableRange, at hostarch.AccessType, ts []Translation, terr error) error { // Verify that the inputs to Mappable.Translate were valid. if !required.WellFormed() || required.Length() == 0 { panic(fmt.Sprintf("invalid required range: %v", required)) } - if !usermem.Addr(required.Start).IsPageAligned() || !usermem.Addr(required.End).IsPageAligned() { + if !hostarch.Addr(required.Start).IsPageAligned() || !hostarch.Addr(required.End).IsPageAligned() { panic(fmt.Sprintf("unaligned required range: %v", required)) } if !optional.IsSupersetOf(required) { panic(fmt.Sprintf("optional range %v is not a superset of required range %v", optional, required)) } - if !usermem.Addr(optional.Start).IsPageAligned() || !usermem.Addr(optional.End).IsPageAligned() { + if !hostarch.Addr(optional.Start).IsPageAligned() || !hostarch.Addr(optional.End).IsPageAligned() { panic(fmt.Sprintf("unaligned optional range: %v", optional)) } @@ -148,13 +148,13 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp if !t.Source.WellFormed() || t.Source.Length() == 0 { return fmt.Errorf("Translation %+v has invalid Source", t) } - if !usermem.Addr(t.Source.Start).IsPageAligned() || !usermem.Addr(t.Source.End).IsPageAligned() { + if !hostarch.Addr(t.Source.Start).IsPageAligned() || !hostarch.Addr(t.Source.End).IsPageAligned() { return fmt.Errorf("Translation %+v has unaligned Source", t) } if t.File == nil { return fmt.Errorf("Translation %+v has nil File", t) } - if !usermem.Addr(t.Offset).IsPageAligned() { + if !hostarch.Addr(t.Offset).IsPageAligned() { return fmt.Errorf("Translation %+v has unaligned Offset", t) } // Translations must be contiguous and in increasing order of @@ -210,7 +210,7 @@ func (mr MappableRange) String() string { return fmt.Sprintf("[%#x, %#x)", mr.Start, mr.End) } -// MappingSpace represents a mutable mapping from usermem.Addrs to (Mappable, +// MappingSpace represents a mutable mapping from hostarch.Addrs to (Mappable, // uint64 offset) pairs. 
type MappingSpace interface { // Invalidate is called to notify the MappingSpace that values returned by @@ -223,7 +223,7 @@ type MappingSpace interface { // Preconditions: // * ar.Length() != 0. // * ar must be page-aligned. - Invalidate(ar usermem.AddrRange, opts InvalidateOpts) + Invalidate(ar hostarch.AddrRange, opts InvalidateOpts) } // InvalidateOpts holds options to MappingSpace.Invalidate. @@ -321,7 +321,7 @@ type MMapOpts struct { Offset uint64 // Addr is the suggested address for the mapping. - Addr usermem.Addr + Addr hostarch.Addr // Fixed specifies whether this is a fixed mapping (it must be located at // Addr). @@ -338,7 +338,7 @@ type MMapOpts struct { Map32Bit bool // Perms is the set of permissions to the applied to this mapping. - Perms usermem.AccessType + Perms hostarch.AccessType // MaxPerms limits the set of permissions that may ever apply to this // mapping. If Mappable is not nil, all memmap.Translations returned by @@ -346,7 +346,7 @@ type MMapOpts struct { // // Preconditions: MaxAccessType should be an effective AccessType, as // access cannot be limited beyond effective AccessTypes. - MaxPerms usermem.AccessType + MaxPerms hostarch.AccessType // Private is true if writes to the mapping should be propagated to a copy // that is exclusive to the MemoryManager. @@ -375,6 +375,11 @@ type MMapOpts struct { // // If Force is true, Unmap and Fixed must be true. Force bool + + // SentryOwnedContent indicates the sentry exclusively controls the + // underlying memory backing the mapping thus the memory content is + // guaranteed not to be modified outside the sentry's purview. + SentryOwnedContent bool } // File represents a host file that may be mapped into an platform.AddressSpace. @@ -410,7 +415,7 @@ type File interface { // // Postconditions: The returned mapping is valid as long as at least one // reference is held on the mapped pages. - MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) + MapInternal(fr FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) // FD returns the file descriptor represented by the File. 
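MMapOpts gains the SentryOwnedContent flag in this change. A hedged sketch of how a sentry-internal caller might set it when creating a private, read-only mapping whose backing memory only the sentry can modify; the helper name and its surrounding code are assumptions, not part of the diff:

package example

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/memmap"
	"gvisor.dev/gvisor/pkg/sentry/mm"
)

// mapSentryOwnedPage maps one private, read-only page whose contents are
// managed entirely by the sentry. Hypothetical helper for illustration.
func mapSentryOwnedPage(ctx context.Context, m *mm.MemoryManager) (hostarch.Addr, error) {
	return m.MMap(ctx, memmap.MMapOpts{
		Length:   hostarch.PageSize,
		Private:  true,
		Perms:    hostarch.Read,
		MaxPerms: hostarch.Read,
		// New in this change: the sentry exclusively controls the backing
		// memory, so its contents cannot change outside the sentry's purview.
		SentryOwnedContent: true,
	})
}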
// diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 6dbeccfe2..b417c2da7 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -28,14 +28,14 @@ go_template_instance( "trackGaps": "1", }, imports = { - "usermem": "gvisor.dev/gvisor/pkg/usermem", + "hostarch": "gvisor.dev/gvisor/pkg/hostarch", }, package = "mm", prefix = "vma", template = "//pkg/segment:generic_set", types = { - "Key": "usermem.Addr", - "Range": "usermem.AddrRange", + "Key": "hostarch.Addr", + "Range": "hostarch.AddrRange", "Value": "vma", "Functions": "vmaSetFunctions", }, @@ -48,14 +48,14 @@ go_template_instance( "minDegree": "8", }, imports = { - "usermem": "gvisor.dev/gvisor/pkg/usermem", + "hostarch": "gvisor.dev/gvisor/pkg/hostarch", }, package = "mm", prefix = "pma", template = "//pkg/segment:generic_set", types = { - "Key": "usermem.Addr", - "Range": "usermem.AddrRange", + "Key": "hostarch.Addr", + "Range": "hostarch.AddrRange", "Value": "pma", "Functions": "pmaSetFunctions", }, @@ -125,6 +125,7 @@ go_library( "//pkg/abi/linux", "//pkg/atomicbitops", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", @@ -155,6 +156,7 @@ go_test( library = ":mm", deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/arch", "//pkg/sentry/contexttest", "//pkg/sentry/limits", diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go index a93e76c75..534e0e957 100644 --- a/pkg/sentry/mm/address_space.go +++ b/pkg/sentry/mm/address_space.go @@ -19,8 +19,8 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/usermem" ) // AddressSpace returns the platform.AddressSpace bound to mm. @@ -172,17 +172,17 @@ func (mm *MemoryManager) Deactivate() { // * ar.Length() != 0. // * ar must be page-aligned. // * pseg == mm.pmas.LowerBoundSegment(ar.Start). -func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, precommit bool) error { +func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar hostarch.AddrRange, precommit bool) error { // By default, map entire pmas at a time, under the assumption that there // is no cost to mapping more of a pma than necessary. - mapAR := usermem.AddrRange{0, ^usermem.Addr(usermem.PageSize - 1)} + mapAR := hostarch.AddrRange{0, ^hostarch.Addr(hostarch.PageSize - 1)} if precommit { // When explicitly precommitting, only map ar, since overmapping may // incur unexpected resource usage. mapAR = ar } else if mapUnit := mm.p.MapUnit(); mapUnit != 0 { // Limit the range we map to ar, aligned to mapUnit. - mapMask := usermem.Addr(mapUnit - 1) + mapMask := hostarch.Addr(mapUnit - 1) mapAR.Start = ar.Start &^ mapMask // If rounding ar.End up overflows, just keep the existing mapAR.End. if end := (ar.End + mapMask) &^ mapMask; end >= ar.End { @@ -218,7 +218,7 @@ func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, pre // unmapASLocked removes all AddressSpace mappings for addresses in ar. // // Preconditions: mm.activeMu must be locked. -func (mm *MemoryManager) unmapASLocked(ar usermem.AddrRange) { +func (mm *MemoryManager) unmapASLocked(ar hostarch.AddrRange) { if mm.as == nil { // No AddressSpace? Force all mappings to be unmapped on the next // Activate. 
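mapASLocked above widens the range it maps out to the platform's MapUnit unless the caller asked for precommit; the masking is easy to misread in diff form. Here is the same rounding in isolation, where the 2 MB map unit and the addresses are illustrative assumptions:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

func main() {
	const mapUnit = 0x200000 // hypothetical 2 MB platform map unit

	ar := hostarch.AddrRange{Start: 0x4a1000, End: 0x4a3000}
	mapMask := hostarch.Addr(mapUnit - 1)

	mapAR := ar
	mapAR.Start = ar.Start &^ mapMask // round down to a map-unit boundary
	// Round the end up, keeping the original end if the addition overflows.
	if end := (ar.End + mapMask) &^ mapMask; end >= ar.End {
		mapAR.End = end
	}

	fmt.Printf("ar=[%#x, %#x) mapAR=[%#x, %#x)\n", ar.Start, ar.End, mapAR.Start, mapAR.End)
}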
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 5ab2ef79f..346866d3c 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -17,6 +17,7 @@ package mm import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" @@ -83,7 +84,7 @@ func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) // the same address. Then it would be unmapping memory that it doesn't own. // This is, however, the way Linux implements AIO. Keeps the same [weird] // semantics in case anyone relies on it. - mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) + mm.MUnmap(ctx, hostarch.Addr(id), aioRingBufferSize) delete(mm.aioManager.contexts, id) aioCtx.destroy() @@ -259,7 +260,7 @@ type aioMappable struct { fr memmap.FileRange } -var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp()) +var aioRingBufferSize = uint64(hostarch.Addr(linux.AIORingSize).MustRoundUp()) func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) { fr, err := mfp.MemoryFile().Allocate(aioRingBufferSize, usage.Anonymous) @@ -300,7 +301,7 @@ func (m *aioMappable) Msync(ctx context.Context, mr memmap.MappableRange) error } // AddMapping implements memmap.Mappable.AddMapping. -func (m *aioMappable) AddMapping(_ context.Context, _ memmap.MappingSpace, ar usermem.AddrRange, offset uint64, _ bool) error { +func (m *aioMappable) AddMapping(_ context.Context, _ memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, _ bool) error { // Don't allow mappings to be expanded (in Linux, fs/aio.c:aio_ring_mmap() // sets VM_DONTEXPAND). if offset != 0 || uint64(ar.Length()) != aioRingBufferSize { @@ -310,11 +311,11 @@ func (m *aioMappable) AddMapping(_ context.Context, _ memmap.MappingSpace, ar us } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (m *aioMappable) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) { +func (m *aioMappable) RemoveMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, uint64, bool) { } // CopyMapping implements memmap.Mappable.CopyMapping. -func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, _ bool) error { +func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, _ bool) error { // Don't allow mappings to be expanded (in Linux, fs/aio.c:aio_ring_mmap() // sets VM_DONTEXPAND). if offset != 0 || uint64(dstAR.Length()) != aioRingBufferSize { @@ -346,7 +347,7 @@ func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, s } // Translate implements memmap.Mappable.Translate. 
-func (m *aioMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (m *aioMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { var err error if required.End > m.fr.Length() { err = &memmap.BusError{syserror.EFAULT} @@ -357,7 +358,7 @@ func (m *aioMappable) Translate(ctx context.Context, required, optional memmap.M Source: source, File: m.mfp.MemoryFile(), Offset: m.fr.Start + source.Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }, }, err } @@ -389,8 +390,8 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint // Linux uses "do_mmap_pgoff(..., PROT_READ | PROT_WRITE, ...)" in // fs/aio.c:aio_setup_ring(). Since we don't implement AIO_RING_MAGIC, // user mode should not write to this page. - Perms: usermem.Read, - MaxPerms: usermem.Read, + Perms: hostarch.Read, + MaxPerms: hostarch.Read, }) if err != nil { return 0, err @@ -435,6 +436,6 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC // bytes from id). func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool { var buf [4]byte - _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) + _, err := mm.CopyIn(ctx, hostarch.Addr(id), buf[:], usermem.IOOpts{}) return err == nil } diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go index a8ac48080..16f318ab3 100644 --- a/pkg/sentry/mm/io.go +++ b/pkg/sentry/mm/io.go @@ -16,6 +16,7 @@ package mm import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/syserror" @@ -60,11 +61,11 @@ const ( rwMapMinBytes = 512 ) -// CheckIORange is similar to usermem.Addr.ToRange, but applies bounds checks +// CheckIORange is similar to hostarch.Addr.ToRange, but applies bounds checks // consistent with Linux's arch/x86/include/asm/uaccess.h:access_ok(). // // Preconditions: length >= 0. -func (mm *MemoryManager) CheckIORange(addr usermem.Addr, length int64) (usermem.AddrRange, bool) { +func (mm *MemoryManager) CheckIORange(addr hostarch.Addr, length int64) (hostarch.AddrRange, bool) { // Note that access_ok() constrains end even if length == 0. ar, ok := addr.ToRange(uint64(length)) return ar, (ok && ar.End <= mm.layout.MaxAddr) @@ -72,7 +73,7 @@ func (mm *MemoryManager) CheckIORange(addr usermem.Addr, length int64) (usermem. // checkIOVec applies bound checks consistent with Linux's // arch/x86/include/asm/uaccess.h:access_ok() to ars. -func (mm *MemoryManager) checkIOVec(ars usermem.AddrRangeSeq) bool { +func (mm *MemoryManager) checkIOVec(ars hostarch.AddrRangeSeq) bool { for !ars.IsEmpty() { ar := ars.Head() if _, ok := mm.CheckIORange(ar.Start, int64(ar.Length())); !ok { @@ -100,7 +101,7 @@ func translateIOError(ctx context.Context, err error) error { } // CopyOut implements usermem.IO.CopyOut. -func (mm *MemoryManager) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { +func (mm *MemoryManager) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts usermem.IOOpts) (int, error) { ar, ok := mm.CheckIORange(addr, int64(len(src))) if !ok { return 0, syserror.EFAULT @@ -116,24 +117,24 @@ func (mm *MemoryManager) CopyOut(ctx context.Context, addr usermem.Addr, src []b } // Go through internal mappings. 
- n64, err := mm.withInternalMappings(ctx, ar, usermem.Write, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { + n64, err := mm.withInternalMappings(ctx, ar, hostarch.Write, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { n, err := safemem.CopySeq(ims, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src))) return n, translateIOError(ctx, err) }) return int(n64), err } -func (mm *MemoryManager) asCopyOut(ctx context.Context, addr usermem.Addr, src []byte) (int, error) { +func (mm *MemoryManager) asCopyOut(ctx context.Context, addr hostarch.Addr, src []byte) (int, error) { var done int for { - n, err := mm.as.CopyOut(addr+usermem.Addr(done), src[done:]) + n, err := mm.as.CopyOut(addr+hostarch.Addr(done), src[done:]) done += n if err == nil { return done, nil } if f, ok := err.(platform.SegmentationFault); ok { ar, _ := addr.ToRange(uint64(len(src))) - if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Write); err != nil { + if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.Write); err != nil { return done, err } continue @@ -143,7 +144,7 @@ func (mm *MemoryManager) asCopyOut(ctx context.Context, addr usermem.Addr, src [ } // CopyIn implements usermem.IO.CopyIn. -func (mm *MemoryManager) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { +func (mm *MemoryManager) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts usermem.IOOpts) (int, error) { ar, ok := mm.CheckIORange(addr, int64(len(dst))) if !ok { return 0, syserror.EFAULT @@ -159,24 +160,24 @@ func (mm *MemoryManager) CopyIn(ctx context.Context, addr usermem.Addr, dst []by } // Go through internal mappings. - n64, err := mm.withInternalMappings(ctx, ar, usermem.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { + n64, err := mm.withInternalMappings(ctx, ar, hostarch.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { n, err := safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), ims) return n, translateIOError(ctx, err) }) return int(n64), err } -func (mm *MemoryManager) asCopyIn(ctx context.Context, addr usermem.Addr, dst []byte) (int, error) { +func (mm *MemoryManager) asCopyIn(ctx context.Context, addr hostarch.Addr, dst []byte) (int, error) { var done int for { - n, err := mm.as.CopyIn(addr+usermem.Addr(done), dst[done:]) + n, err := mm.as.CopyIn(addr+hostarch.Addr(done), dst[done:]) done += n if err == nil { return done, nil } if f, ok := err.(platform.SegmentationFault); ok { ar, _ := addr.ToRange(uint64(len(dst))) - if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Read); err != nil { + if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.Read); err != nil { return done, err } continue @@ -186,7 +187,7 @@ func (mm *MemoryManager) asCopyIn(ctx context.Context, addr usermem.Addr, dst [] } // ZeroOut implements usermem.IO.ZeroOut. -func (mm *MemoryManager) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { +func (mm *MemoryManager) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { ar, ok := mm.CheckIORange(addr, toZero) if !ok { return 0, syserror.EFAULT @@ -202,23 +203,23 @@ func (mm *MemoryManager) ZeroOut(ctx context.Context, addr usermem.Addr, toZero } // Go through internal mappings. 
- return mm.withInternalMappings(ctx, ar, usermem.Write, opts.IgnorePermissions, func(dsts safemem.BlockSeq) (uint64, error) { + return mm.withInternalMappings(ctx, ar, hostarch.Write, opts.IgnorePermissions, func(dsts safemem.BlockSeq) (uint64, error) { n, err := safemem.ZeroSeq(dsts) return n, translateIOError(ctx, err) }) } -func (mm *MemoryManager) asZeroOut(ctx context.Context, addr usermem.Addr, toZero int64) (int64, error) { +func (mm *MemoryManager) asZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64) (int64, error) { var done int64 for { - n, err := mm.as.ZeroOut(addr+usermem.Addr(done), uintptr(toZero-done)) + n, err := mm.as.ZeroOut(addr+hostarch.Addr(done), uintptr(toZero-done)) done += int64(n) if err == nil { return done, nil } if f, ok := err.(platform.SegmentationFault); ok { ar, _ := addr.ToRange(uint64(toZero)) - if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Write); err != nil { + if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.Write); err != nil { return done, err } continue @@ -228,7 +229,7 @@ func (mm *MemoryManager) asZeroOut(ctx context.Context, addr usermem.Addr, toZer } // CopyOutFrom implements usermem.IO.CopyOutFrom. -func (mm *MemoryManager) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { +func (mm *MemoryManager) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { if !mm.checkIOVec(ars) { return 0, syserror.EFAULT } @@ -269,11 +270,11 @@ func (mm *MemoryManager) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeS } // Go through internal mappings. - return mm.withVecInternalMappings(ctx, ars, usermem.Write, opts.IgnorePermissions, src.ReadToBlocks) + return mm.withVecInternalMappings(ctx, ars, hostarch.Write, opts.IgnorePermissions, src.ReadToBlocks) } // CopyInTo implements usermem.IO.CopyInTo. -func (mm *MemoryManager) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { +func (mm *MemoryManager) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { if !mm.checkIOVec(ars) { return 0, syserror.EFAULT } @@ -306,11 +307,11 @@ func (mm *MemoryManager) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, } // Go through internal mappings. - return mm.withVecInternalMappings(ctx, ars, usermem.Read, opts.IgnorePermissions, dst.WriteFromBlocks) + return mm.withVecInternalMappings(ctx, ars, hostarch.Read, opts.IgnorePermissions, dst.WriteFromBlocks) } // SwapUint32 implements usermem.IO.SwapUint32. -func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) { +func (mm *MemoryManager) SwapUint32(ctx context.Context, addr hostarch.Addr, new uint32, opts usermem.IOOpts) (uint32, error) { ar, ok := mm.CheckIORange(addr, 4) if !ok { return 0, syserror.EFAULT @@ -324,7 +325,7 @@ func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new return old, nil } if f, ok := err.(platform.SegmentationFault); ok { - if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.ReadWrite); err != nil { + if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.ReadWrite); err != nil { return 0, err } continue @@ -335,7 +336,7 @@ func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new // Go through internal mappings. 
var old uint32 - _, err := mm.withInternalMappings(ctx, ar, usermem.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { + _, err := mm.withInternalMappings(ctx, ar, hostarch.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { if ims.NumBlocks() != 1 || ims.NumBytes() != 4 { // Atomicity is unachievable across mappings. return 0, syserror.EFAULT @@ -353,7 +354,7 @@ func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new } // CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32. -func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) { +func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr hostarch.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) { ar, ok := mm.CheckIORange(addr, 4) if !ok { return 0, syserror.EFAULT @@ -367,7 +368,7 @@ func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem. return prev, nil } if f, ok := err.(platform.SegmentationFault); ok { - if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.ReadWrite); err != nil { + if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.ReadWrite); err != nil { return 0, err } continue @@ -378,7 +379,7 @@ func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem. // Go through internal mappings. var prev uint32 - _, err := mm.withInternalMappings(ctx, ar, usermem.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { + _, err := mm.withInternalMappings(ctx, ar, hostarch.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { if ims.NumBlocks() != 1 || ims.NumBytes() != 4 { // Atomicity is unachievable across mappings. return 0, syserror.EFAULT @@ -396,7 +397,7 @@ func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem. } // LoadUint32 implements usermem.IO.LoadUint32. -func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) { +func (mm *MemoryManager) LoadUint32(ctx context.Context, addr hostarch.Addr, opts usermem.IOOpts) (uint32, error) { ar, ok := mm.CheckIORange(addr, 4) if !ok { return 0, syserror.EFAULT @@ -410,7 +411,7 @@ func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts return val, nil } if f, ok := err.(platform.SegmentationFault); ok { - if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Read); err != nil { + if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.Read); err != nil { return 0, err } continue @@ -421,7 +422,7 @@ func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts // Go through internal mappings. var val uint32 - _, err := mm.withInternalMappings(ctx, ar, usermem.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { + _, err := mm.withInternalMappings(ctx, ar, hostarch.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) { if ims.NumBlocks() != 1 || ims.NumBytes() != 4 { // Atomicity is unachievable across mappings. return 0, syserror.EFAULT @@ -445,11 +446,11 @@ func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts // * mm.as != nil. // * ioar.Length() != 0. // * ioar.Contains(addr). 
-func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr usermem.Addr, ioar usermem.AddrRange, at usermem.AccessType) error { +func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr hostarch.Addr, ioar hostarch.AddrRange, at hostarch.AccessType) error { // Try to map all remaining pages in the I/O operation. This RoundUp can't // overflow because otherwise it would have been caught by CheckIORange. end, _ := ioar.End.RoundUp() - ar := usermem.AddrRange{addr.RoundDown(), end} + ar := hostarch.AddrRange{addr.RoundDown(), end} // Don't bother trying existingPMAsLocked; in most cases, if we did have // existing pmas, we wouldn't have faulted. @@ -498,7 +499,7 @@ func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr usermem.Addr, // more useful for usermem.IO methods. // // Preconditions: 0 < ar.Length() <= math.MaxInt64. -func (mm *MemoryManager) withInternalMappings(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { +func (mm *MemoryManager) withInternalMappings(ctx context.Context, ar hostarch.AddrRange, at hostarch.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { // If pmas are already available, we can do IO without touching mm.vmas or // mm.mappingMu. mm.activeMu.RLock() @@ -567,7 +568,7 @@ func (mm *MemoryManager) withInternalMappings(ctx context.Context, ar usermem.Ad // internal mappings for the subset of ars for which this property holds. // // Preconditions: !ars.IsEmpty(). -func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { +func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars hostarch.AddrRangeSeq, at hostarch.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { // withInternalMappings is faster than withVecInternalMappings because of // iterator plumbing (this isn't generally practical in the vector case due // to iterator invalidation between AddrRanges). Use it if possible. @@ -630,12 +631,12 @@ func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars userme // truncatedAddrRangeSeq returns a copy of ars, but with the end truncated to // at most address end on AddrRange arsit.Head(). It is used in vector I/O paths to -// truncate usermem.AddrRangeSeq when errors occur. +// truncate hostarch.AddrRangeSeq when errors occur. // // Preconditions: // * !arsit.IsEmpty(). // * end <= arsit.Head().End. -func truncatedAddrRangeSeq(ars, arsit usermem.AddrRangeSeq, end usermem.Addr) usermem.AddrRangeSeq { +func truncatedAddrRangeSeq(ars, arsit hostarch.AddrRangeSeq, end hostarch.Addr) hostarch.AddrRangeSeq { ar := arsit.Head() if end <= ar.Start { return ars.TakeFirst64(ars.NumBytes() - arsit.NumBytes()) diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 120707429..a79ef9223 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -19,12 +19,12 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/usermem" ) // NewMemoryManager returns a new MemoryManager with no mappings and 1 user. 
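The io.go hunks above change only the address and permission types (usermem.Addr to hostarch.Addr, usermem.AccessType to hostarch.AccessType); the calling convention of the MemoryManager's usermem.IO implementation is unchanged. A sketch of a copy round trip through it, where the helper itself is an assumption for illustration:

package example

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/mm"
	"gvisor.dev/gvisor/pkg/usermem"
)

// copyRoundTrip writes four bytes at addr and reads them back, letting the
// MemoryManager fault in pmas as needed. Hypothetical helper, not part of
// the change.
func copyRoundTrip(ctx context.Context, m *mm.MemoryManager, addr hostarch.Addr) ([]byte, error) {
	if _, err := m.CopyOut(ctx, addr, []byte{1, 2, 3, 4}, usermem.IOOpts{}); err != nil {
		return nil, err
	}
	dst := make([]byte, 4)
	if _, err := m.CopyIn(ctx, addr, dst, usermem.IOOpts{}); err != nil {
		return nil, err
	}
	return dst, nil
}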
@@ -139,7 +139,7 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { } srcvseg := mm.vmas.FirstSegment() dstpgap := mm2.pmas.FirstGap() - var unmapAR usermem.AddrRange + var unmapAR hostarch.AddrRange for srcpseg := mm.pmas.FirstSegment(); srcpseg.Ok(); srcpseg = srcpseg.NextSegment() { pma := srcpseg.ValuePtr() if !pma.private { diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go index 0cfd60f6c..28c5fead9 100644 --- a/pkg/sentry/mm/metadata.go +++ b/pkg/sentry/mm/metadata.go @@ -16,9 +16,9 @@ package mm import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsbridge" - "gvisor.dev/gvisor/pkg/usermem" ) // Dumpability describes if and how core dumps should be created. @@ -54,14 +54,14 @@ func (mm *MemoryManager) SetDumpability(d Dumpability) { // ArgvStart returns the start of the application argument vector. // // There is no guarantee that this value is sensible w.r.t. ArgvEnd. -func (mm *MemoryManager) ArgvStart() usermem.Addr { +func (mm *MemoryManager) ArgvStart() hostarch.Addr { mm.metadataMu.Lock() defer mm.metadataMu.Unlock() return mm.argv.Start } // SetArgvStart sets the start of the application argument vector. -func (mm *MemoryManager) SetArgvStart(a usermem.Addr) { +func (mm *MemoryManager) SetArgvStart(a hostarch.Addr) { mm.metadataMu.Lock() defer mm.metadataMu.Unlock() mm.argv.Start = a @@ -70,14 +70,14 @@ func (mm *MemoryManager) SetArgvStart(a usermem.Addr) { // ArgvEnd returns the end of the application argument vector. // // There is no guarantee that this value is sensible w.r.t. ArgvStart. -func (mm *MemoryManager) ArgvEnd() usermem.Addr { +func (mm *MemoryManager) ArgvEnd() hostarch.Addr { mm.metadataMu.Lock() defer mm.metadataMu.Unlock() return mm.argv.End } // SetArgvEnd sets the end of the application argument vector. -func (mm *MemoryManager) SetArgvEnd(a usermem.Addr) { +func (mm *MemoryManager) SetArgvEnd(a hostarch.Addr) { mm.metadataMu.Lock() defer mm.metadataMu.Unlock() mm.argv.End = a @@ -86,14 +86,14 @@ func (mm *MemoryManager) SetArgvEnd(a usermem.Addr) { // EnvvStart returns the start of the application environment vector. // // There is no guarantee that this value is sensible w.r.t. EnvvEnd. -func (mm *MemoryManager) EnvvStart() usermem.Addr { +func (mm *MemoryManager) EnvvStart() hostarch.Addr { mm.metadataMu.Lock() defer mm.metadataMu.Unlock() return mm.envv.Start } // SetEnvvStart sets the start of the application environment vector. -func (mm *MemoryManager) SetEnvvStart(a usermem.Addr) { +func (mm *MemoryManager) SetEnvvStart(a hostarch.Addr) { mm.metadataMu.Lock() defer mm.metadataMu.Unlock() mm.envv.Start = a @@ -102,14 +102,14 @@ func (mm *MemoryManager) SetEnvvStart(a usermem.Addr) { // EnvvEnd returns the end of the application environment vector. // // There is no guarantee that this value is sensible w.r.t. EnvvStart. -func (mm *MemoryManager) EnvvEnd() usermem.Addr { +func (mm *MemoryManager) EnvvEnd() hostarch.Addr { mm.metadataMu.Lock() defer mm.metadataMu.Unlock() return mm.envv.End } // SetEnvvEnd sets the end of the application environment vector. 
-func (mm *MemoryManager) SetEnvvEnd(a usermem.Addr) { +func (mm *MemoryManager) SetEnvvEnd(a hostarch.Addr) { mm.metadataMu.Lock() defer mm.metadataMu.Unlock() mm.envv.End = a diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 92cc87d84..57969b26c 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -36,6 +36,7 @@ package mm import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -43,7 +44,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) // MemoryManager implements a virtual address space. @@ -97,7 +97,7 @@ type MemoryManager struct { // binary into the mm. // // brk is protected by mappingMu. - brk usermem.AddrRange + brk hostarch.AddrRange // usageAS is vmas.Span(), cached to accelerate RLIMIT_AS checks. // @@ -198,14 +198,14 @@ type MemoryManager struct { // requirements apply to argv; we do not require that argv.WellFormed(). // // argv is protected by metadataMu. - argv usermem.AddrRange + argv hostarch.AddrRange // envv is the application envv. This is set up by the loader and may be // modified by prctl(PR_SET_MM_ENV_START/PR_SET_MM_ENV_END). No // requirements apply to envv; we do not require that envv.WellFormed(). // // envv is protected by metadataMu. - envv usermem.AddrRange + envv hostarch.AddrRange // auxv is the ELF's auxiliary vector. // @@ -268,20 +268,20 @@ type vma struct { // realPerms are the memory permissions on this vma, as defined by the // application. - realPerms usermem.AccessType `state:".(int)"` + realPerms hostarch.AccessType `state:".(int)"` // effectivePerms are the memory permissions on this vma which are // actually used to control access. // // Invariant: effectivePerms == realPerms.Effective(). - effectivePerms usermem.AccessType `state:"manual"` + effectivePerms hostarch.AccessType `state:"manual"` // maxPerms limits the set of permissions that may ever apply to this // memory, as well as accesses for which usermem.IOOpts.IgnorePermissions // is true (e.g. ptrace(PTRACE_POKEDATA)). // // Invariant: maxPerms == maxPerms.Effective(). - maxPerms usermem.AccessType `state:"manual"` + maxPerms hostarch.AccessType `state:"manual"` // private is true if this is a MAP_PRIVATE mapping, such that writes to // the mapping are propagated to a copy. @@ -421,8 +421,8 @@ type pma struct { off uint64 // translatePerms is the permissions returned by memmap.Mappable.Translate. - // If private is true, translatePerms is usermem.AnyAccess. - translatePerms usermem.AccessType + // If private is true, translatePerms is hostarch.AnyAccess. + translatePerms hostarch.AccessType // effectivePerms is the permissions allowed for non-ignorePermissions // accesses. maxPerms is the permissions allowed for ignorePermissions @@ -432,8 +432,8 @@ type pma struct { // // These are stored in the pma so that the IO implementation can avoid // iterating mm.vmas when pmas already exist. - effectivePerms usermem.AccessType - maxPerms usermem.AccessType + effectivePerms hostarch.AccessType + maxPerms hostarch.AccessType // needCOW is true if writes to the mapping must be propagated to a copy. 
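The vma invariants above are phrased in terms of AccessType.Effective(), which promotes a permission set to what is actually enforceable (write or execute access implies read). A tiny sketch; the particular combination is arbitrary:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

func main() {
	realPerms := hostarch.AccessType{Write: true} // write-only, as an application might request
	effectivePerms := realPerms.Effective()       // read is implied: {Read: true, Write: true}

	fmt.Printf("real=%+v effective=%+v\n", realPerms, effectivePerms)
}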
needCOW bool @@ -465,7 +465,7 @@ type privateRefs struct { } type invalidateArgs struct { - ar usermem.AddrRange + ar hostarch.AddrRange opts memmap.InvalidateOpts } diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index bc53bd41e..1304b0a2f 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -18,6 +18,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/limits" @@ -51,7 +52,7 @@ func TestUsageASUpdates(t *testing.T) { defer mm.DecUsers(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 2 * usermem.PageSize, + Length: 2 * hostarch.PageSize, Private: true, }) if err != nil { @@ -62,7 +63,7 @@ func TestUsageASUpdates(t *testing.T) { t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage) } - mm.MUnmap(ctx, addr, usermem.PageSize) + mm.MUnmap(ctx, addr, hostarch.PageSize) realUsage = mm.realUsageAS() if mm.usageAS != realUsage { t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage) @@ -86,10 +87,10 @@ func TestDataASUpdates(t *testing.T) { defer mm.DecUsers(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 3 * usermem.PageSize, + Length: 3 * hostarch.PageSize, Private: true, - Perms: usermem.Write, - MaxPerms: usermem.AnyAccess, + Perms: hostarch.Write, + MaxPerms: hostarch.AnyAccess, }) if err != nil { t.Fatalf("MMap got err %v want nil", err) @@ -102,19 +103,19 @@ func TestDataASUpdates(t *testing.T) { t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) } - mm.MUnmap(ctx, addr, usermem.PageSize) + mm.MUnmap(ctx, addr, hostarch.PageSize) realDataAS = mm.realDataAS() if mm.dataAS != realDataAS { t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) } - mm.MProtect(addr+usermem.PageSize, usermem.PageSize, usermem.Read, false) + mm.MProtect(addr+hostarch.PageSize, hostarch.PageSize, hostarch.Read, false) realDataAS = mm.realDataAS() if mm.dataAS != realDataAS { t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) } - mm.MRemap(ctx, addr+2*usermem.PageSize, usermem.PageSize, 2*usermem.PageSize, MRemapOpts{ + mm.MRemap(ctx, addr+2*hostarch.PageSize, hostarch.PageSize, 2*hostarch.PageSize, MRemapOpts{ Move: MRemapMayMove, }) realDataAS = mm.realDataAS() @@ -133,7 +134,7 @@ func TestBrkDataLimitUpdates(t *testing.T) { // Try to extend the brk by one page and expect doing so to fail. 
oldBrk, _ := mm.Brk(ctx, 0) - if newBrk, _ := mm.Brk(ctx, oldBrk+usermem.PageSize); newBrk != oldBrk { + if newBrk, _ := mm.Brk(ctx, oldBrk+hostarch.PageSize); newBrk != oldBrk { t.Errorf("brk() increased data segment above RLIMIT_DATA (old brk = %#x, new brk = %#x", oldBrk, newBrk) } } @@ -145,10 +146,10 @@ func TestIOAfterUnmap(t *testing.T) { defer mm.DecUsers(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: usermem.PageSize, + Length: hostarch.PageSize, Private: true, - Perms: usermem.Read, - MaxPerms: usermem.AnyAccess, + Perms: hostarch.Read, + MaxPerms: hostarch.AnyAccess, }) if err != nil { t.Fatalf("MMap got err %v want nil", err) @@ -164,7 +165,7 @@ func TestIOAfterUnmap(t *testing.T) { t.Errorf("CopyIn got %d want 1", n) } - err = mm.MUnmap(ctx, addr, usermem.PageSize) + err = mm.MUnmap(ctx, addr, hostarch.PageSize) if err != nil { t.Fatalf("MUnmap got err %v want nil", err) } @@ -185,10 +186,10 @@ func TestIOAfterMProtect(t *testing.T) { defer mm.DecUsers(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: usermem.PageSize, + Length: hostarch.PageSize, Private: true, - Perms: usermem.ReadWrite, - MaxPerms: usermem.AnyAccess, + Perms: hostarch.ReadWrite, + MaxPerms: hostarch.AnyAccess, }) if err != nil { t.Fatalf("MMap got err %v want nil", err) @@ -204,7 +205,7 @@ func TestIOAfterMProtect(t *testing.T) { t.Errorf("CopyOut got %d want 1", n) } - err = mm.MProtect(addr, usermem.PageSize, usermem.Read, false) + err = mm.MProtect(addr, hostarch.PageSize, hostarch.Read, false) if err != nil { t.Errorf("MProtect got err %v want nil", err) } diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 7e5f7de64..5583f62b2 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -18,12 +18,12 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // existingPMAsLocked checks that pmas exist for all addresses in ar, and @@ -34,7 +34,7 @@ import ( // Preconditions: // * mm.activeMu must be locked. // * ar.Length() != 0. -func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator { +func (mm *MemoryManager) existingPMAsLocked(ar hostarch.AddrRange, at hostarch.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -70,7 +70,7 @@ func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.Acc // and support access of type (at, ignorePermissions). // // Preconditions: mm.activeMu must be locked. -func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) bool { +func (mm *MemoryManager) existingVecPMAsLocked(ars hostarch.AddrRangeSeq, at hostarch.AccessType, ignorePermissions bool, needInternalMappings bool) bool { for ; !ars.IsEmpty(); ars = ars.Tail() { if ar := ars.Head(); ar.Length() != 0 && !mm.existingPMAsLocked(ar, at, ignorePermissions, needInternalMappings).Ok() { return false @@ -98,7 +98,7 @@ func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at user // * vseg.Range().Contains(ar.Start). // * vmas must exist for all addresses in ar, and support accesses of type at // (i.e. 
permission checks must have been performed against vmas). -func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) { +func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, at hostarch.AccessType) (pmaIterator, pmaGapIterator, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -118,7 +118,7 @@ func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar end = ar.End.RoundDown() alignerr = syserror.EFAULT } - ar = usermem.AddrRange{ar.Start.RoundDown(), end} + ar = hostarch.AddrRange{ar.Start.RoundDown(), end} pstart, pend, perr := mm.getPMAsInternalLocked(ctx, vseg, ar, at) if pend.Start() <= ar.Start { @@ -145,7 +145,7 @@ func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar // * mm.activeMu must be locked for writing. // * vmas must exist for all addresses in ars, and support accesses of type at // (i.e. permission checks must have been performed against vmas). -func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType) (usermem.AddrRangeSeq, error) { +func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars hostarch.AddrRangeSeq, at hostarch.AccessType) (hostarch.AddrRangeSeq, error) { for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() { ar := arsit.Head() if ar.Length() == 0 { @@ -164,7 +164,7 @@ func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrR end = ar.End.RoundDown() alignerr = syserror.EFAULT } - ar = usermem.AddrRange{ar.Start.RoundDown(), end} + ar = hostarch.AddrRange{ar.Start.RoundDown(), end} _, pend, perr := mm.getPMAsInternalLocked(ctx, mm.vmas.FindSegment(ar.Start), ar, at) if perr != nil { @@ -191,7 +191,7 @@ func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrR // // getPMAsInternalLocked is an implementation helper for getPMAsLocked and // getVecPMAsLocked; other clients should call one of those instead. -func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) { +func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, at hostarch.AccessType) (pmaIterator, pmaGapIterator, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -245,7 +245,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter pseg, pgap = mm.pmas.Insert(pgap, allocAR, pma{ file: mf, off: fr.Start, - translatePerms: usermem.AnyAccess, + translatePerms: hostarch.AnyAccess, effectivePerms: vma.effectivePerms, maxPerms: vma.maxPerms, // Since we just allocated this memory and have the @@ -335,7 +335,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter // Neither of these cases has enough spatial locality to // benefit from copying nearby pages, so if the vma is // executable, only copy the pages required. 
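getPMAsLocked and getVecPMAsLocked round the requested range out to page boundaries before building pmas; if rounding the end up would overflow, they fall back to rounding it down and remember EFAULT for the truncated tail. A sketch of just that rounding step, assuming the gvisor module is available; the roundOutToPages name is illustrative and the error path is simplified to a bool:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

// roundOutToPages expands ar to page boundaries the way getPMAsLocked does:
// the start rounds down, the end rounds up, and an end that would overflow is
// rounded down instead, with truncated reported as true.
func roundOutToPages(ar hostarch.AddrRange) (rounded hostarch.AddrRange, truncated bool) {
	end, ok := ar.End.RoundUp()
	if !ok {
		end = ar.End.RoundDown()
		truncated = true
	}
	return hostarch.AddrRange{ar.Start.RoundDown(), end}, truncated
}

func main() {
	ar := hostarch.AddrRange{0x1234, 0x5678}
	rounded, truncated := roundOutToPages(ar)
	fmt.Printf("rounded=%v truncated=%v\n", rounded, truncated)
}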
- var copyAR usermem.AddrRange + var copyAR hostarch.AddrRange if vseg.ValuePtr().effectivePerms.Execute { copyAR = pseg.Range().Intersect(ar) } else { @@ -366,7 +366,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter // Replace the pma with a copy in the part of the address // range where copying was successful. This doesn't change // RSS. - copyAR.End = copyAR.Start + usermem.Addr(fr.Length()) + copyAR.End = copyAR.Start + hostarch.Addr(fr.Length()) if copyAR != pseg.Range() { pseg = mm.pmas.Isolate(pseg, copyAR) pstart = pmaIterator{} // iterators invalidated @@ -380,7 +380,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter mf.IncRef(fr) oldpma.file = mf oldpma.off = fr.Start - oldpma.translatePerms = usermem.AnyAccess + oldpma.translatePerms = hostarch.AnyAccess oldpma.effectivePerms = vma.effectivePerms oldpma.maxPerms = vma.maxPerms oldpma.needCOW = false @@ -499,14 +499,14 @@ const ( // privateAllocUnit may reduce page faults by allowing fewer, larger pmas // to be mapped, but may result in larger amounts of wasted memory in the // presence of fragmentation. privateAllocUnit must be a power-of-2 - // multiple of usermem.PageSize. - privateAllocUnit = usermem.HugePageSize + // multiple of hostarch.PageSize. + privateAllocUnit = hostarch.HugePageSize privateAllocMask = privateAllocUnit - 1 ) -func privateAligned(ar usermem.AddrRange) usermem.AddrRange { - aligned := usermem.AddrRange{ar.Start &^ privateAllocMask, ar.End} +func privateAligned(ar hostarch.AddrRange) hostarch.AddrRange { + aligned := hostarch.AddrRange{ar.Start &^ privateAllocMask, ar.End} if end := (ar.End + privateAllocMask) &^ privateAllocMask; end >= ar.End { aligned.End = end } @@ -548,7 +548,7 @@ func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterat rseg := mm.privateRefs.refs.FindSegment(fr.Start) if rseg.Ok() && rseg.Value() == 1 && fr.End <= rseg.End() { pma.needCOW = false - // pma.private => pma.translatePerms == usermem.AnyAccess + // pma.private => pma.translatePerms == hostarch.AnyAccess vma := vseg.ValuePtr() pma.effectivePerms = vma.effectivePerms pma.maxPerms = vma.maxPerms @@ -558,7 +558,7 @@ func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterat } // Invalidate implements memmap.MappingSpace.Invalidate. -func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) { +func (mm *MemoryManager) Invalidate(ar hostarch.AddrRange, opts memmap.InvalidateOpts) { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -581,7 +581,7 @@ func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.Invalidate // * mm.activeMu must be locked for writing. // * ar.Length() != 0. // * ar must be page-aligned. -func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) { +func (mm *MemoryManager) invalidateLocked(ar hostarch.AddrRange, invalidatePrivate, invalidateShared bool) { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -627,7 +627,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat // Preconditions: // * ar.Length() != 0. // * ar must be page-aligned. 
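privateAllocUnit now comes from hostarch.HugePageSize, but privateAligned still widens a range to that unit with plain mask arithmetic, skipping the end expansion if it would wrap around. The same arithmetic on bare uint64s, as a quick standalone illustration (2 MB huge pages assumed, as on x86-64):

package main

import "fmt"

const (
	hugePageSize     = 2 << 20 // assumed 2 MB huge pages
	privateAllocMask = hugePageSize - 1
)

// privateAligned mirrors mm/pma.go: the start is rounded down to the
// allocation unit and the end is rounded up, unless rounding up overflows.
func privateAligned(start, end uint64) (uint64, uint64) {
	alignedStart := start &^ privateAllocMask
	alignedEnd := end
	if e := (end + privateAllocMask) &^ privateAllocMask; e >= end {
		alignedEnd = e
	}
	return alignedStart, alignedEnd
}

func main() {
	s, e := privateAligned(0x20_1000, 0x20_3000)
	fmt.Printf("aligned range: [%#x, %#x)\n", s, e) // widens to [0x200000, 0x400000)
}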
-func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) { +func (mm *MemoryManager) Pin(ctx context.Context, ar hostarch.AddrRange, at hostarch.AccessType, ignorePermissions bool) ([]PinnedRange, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -683,7 +683,7 @@ func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at userm // PinnedRanges are returned by MemoryManager.Pin. type PinnedRange struct { // Source is the corresponding range of addresses. - Source usermem.AddrRange + Source hostarch.AddrRange // File is the mapped file. File memmap.File @@ -713,7 +713,7 @@ func Unpin(prs []PinnedRange) { // * !oldAR.Overlaps(newAR). // * mm.pmas.IsEmptyRange(newAR). // * oldAR and newAR must be page-aligned. -func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { +func (mm *MemoryManager) movePMAsLocked(oldAR, newAR hostarch.AddrRange) { if checkInvariants { if !oldAR.WellFormed() || oldAR.Length() == 0 || !oldAR.IsPageAligned() { panic(fmt.Sprintf("invalid oldAR: %v", oldAR)) @@ -731,7 +731,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { } type movedPMA struct { - oldAR usermem.AddrRange + oldAR hostarch.AddrRange pma pma } var movedPMAs []movedPMA @@ -751,7 +751,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { pgap := mm.pmas.FindGap(newAR.Start) for i := range movedPMAs { mpma := &movedPMAs[i] - pmaNewAR := usermem.AddrRange{mpma.oldAR.Start + off, mpma.oldAR.End + off} + pmaNewAR := hostarch.AddrRange{mpma.oldAR.Start + off, mpma.oldAR.End + off} pgap = mm.pmas.Insert(pgap, pmaNewAR, mpma.pma).NextGap() } @@ -776,7 +776,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { // // Postconditions: getPMAInternalMappingsLocked does not invalidate iterators // into mm.pmas. -func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) (pmaGapIterator, error) { +func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar hostarch.AddrRange) (pmaGapIterator, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -808,7 +808,7 @@ func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar userm // // Postconditions: getVecPMAInternalMappingsLocked does not invalidate iterators // into mm.pmas. -func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSeq) (usermem.AddrRangeSeq, error) { +func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars hostarch.AddrRangeSeq) (hostarch.AddrRangeSeq, error) { for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() { ar := arsit.Head() if ar.Length() == 0 { @@ -829,7 +829,7 @@ func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSe // in ar. // * ar.Length() != 0. // * pseg.Range().Contains(ar.Start). -func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq { +func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar hostarch.AddrRange) safemem.BlockSeq { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -866,7 +866,7 @@ func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.Add // * mm.activeMu must be locked. 
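movePMAsLocked re-inserts each saved pma at its old range shifted by the displacement between the new and old reservations. A small sketch of that translation, assuming the gvisor module is available; the relocate helper and the explicit off computation are illustrative (the hunk derives the same displacement):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

// relocate shifts ar by the displacement between the old and new reservations,
// mirroring how movePMAsLocked re-inserts each saved pma under newAR.
func relocate(ar, oldAR, newAR hostarch.AddrRange) hostarch.AddrRange {
	off := newAR.Start - oldAR.Start
	return hostarch.AddrRange{ar.Start + off, ar.End + off}
}

func main() {
	oldAR := hostarch.AddrRange{0x400000, 0x500000}
	newAR := hostarch.AddrRange{0x700000, 0x800000}
	pma := hostarch.AddrRange{0x410000, 0x420000} // one pma inside oldAR
	fmt.Printf("moved pma range: %v\n", relocate(pma, oldAR, newAR))
}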
// * Internal mappings must have been previously established for all addresses // in ars. -func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) safemem.BlockSeq { +func (mm *MemoryManager) vecInternalMappingsLocked(ars hostarch.AddrRangeSeq) safemem.BlockSeq { var ims []safemem.Block for ; !ars.IsEmpty(); ars = ars.Tail() { ar := ars.Head() @@ -931,7 +931,7 @@ func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) { // MemoryManager to reflect the insertion of a pma at ar. // // Preconditions: mm.activeMu must be locked for writing. -func (mm *MemoryManager) addRSSLocked(ar usermem.AddrRange) { +func (mm *MemoryManager) addRSSLocked(ar hostarch.AddrRange) { mm.curRSS += uint64(ar.Length()) if mm.curRSS > mm.maxRSS { mm.maxRSS = mm.curRSS @@ -942,19 +942,19 @@ func (mm *MemoryManager) addRSSLocked(ar usermem.AddrRange) { // reflect the removal of a pma at ar. // // Preconditions: mm.activeMu must be locked for writing. -func (mm *MemoryManager) removeRSSLocked(ar usermem.AddrRange) { +func (mm *MemoryManager) removeRSSLocked(ar hostarch.AddrRange) { mm.curRSS -= uint64(ar.Length()) } // pmaSetFunctions implements segment.Functions for pmaSet. type pmaSetFunctions struct{} -func (pmaSetFunctions) MinKey() usermem.Addr { +func (pmaSetFunctions) MinKey() hostarch.Addr { return 0 } -func (pmaSetFunctions) MaxKey() usermem.Addr { - return ^usermem.Addr(0) +func (pmaSetFunctions) MaxKey() hostarch.Addr { + return ^hostarch.Addr(0) } func (pmaSetFunctions) ClearValue(pma *pma) { @@ -962,7 +962,7 @@ func (pmaSetFunctions) ClearValue(pma *pma) { pma.internalMappings = safemem.BlockSeq{} } -func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRange, pma2 pma) (pma, bool) { +func (pmaSetFunctions) Merge(ar1 hostarch.AddrRange, pma1 pma, ar2 hostarch.AddrRange, pma2 pma) (pma, bool) { if pma1.file != pma2.file || pma1.off+uint64(ar1.Length()) != pma2.off || pma1.translatePerms != pma2.translatePerms || @@ -980,7 +980,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa return pma1, true } -func (pmaSetFunctions) Split(ar usermem.AddrRange, p pma, split usermem.Addr) (pma, pma) { +func (pmaSetFunctions) Split(ar hostarch.AddrRange, p pma, split hostarch.Addr) (pma, pma) { newlen1 := uint64(split - ar.Start) p2 := p p2.off += newlen1 @@ -997,7 +997,7 @@ func (pmaSetFunctions) Split(ar usermem.AddrRange, p pma, split usermem.Addr) (p // Preconditions: // * mm.activeMu must be locked. // * addr <= pgap.Start(). -func (mm *MemoryManager) findOrSeekPrevUpperBoundPMA(addr usermem.Addr, pgap pmaGapIterator) pmaIterator { +func (mm *MemoryManager) findOrSeekPrevUpperBoundPMA(addr hostarch.Addr, pgap pmaGapIterator) pmaIterator { if checkInvariants { if !pgap.Ok() { panic("terminal pma iterator") @@ -1045,7 +1045,7 @@ func (pseg pmaIterator) fileRange() memmap.FileRange { // Preconditions: // * pseg.Range().IsSupersetOf(ar). // * ar.Length != 0. 
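pmaSetFunctions.Merge only coalesces neighbouring pmas whose backing file offsets are contiguous: pma2 must start at pma1.off plus the length of ar1 (the matching-file and matching-permissions conditions are omitted here). A minimal sketch of that contiguity test, assuming the gvisor module is available; contiguousOffsets is an illustrative name:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

// contiguousOffsets reports whether two adjacent pmas map contiguous file
// offsets, the offset condition checked by pmaSetFunctions.Merge.
func contiguousOffsets(ar1 hostarch.AddrRange, off1, off2 uint64) bool {
	return off1+uint64(ar1.Length()) == off2
}

func main() {
	ar1 := hostarch.AddrRange{0x1000, 0x3000}             // two pages
	fmt.Println(contiguousOffsets(ar1, 0x10000, 0x12000)) // true: offsets line up
	fmt.Println(contiguousOffsets(ar1, 0x10000, 0x14000)) // false: a hole in the file
}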
-func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange { +func (pseg pmaIterator) fileRangeOf(ar hostarch.AddrRange) memmap.FileRange { if checkInvariants { if !pseg.Ok() { panic("terminal pma iterator") diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go index 73bfbea49..f1440e884 100644 --- a/pkg/sentry/mm/procfs.go +++ b/pkg/sentry/mm/procfs.go @@ -19,9 +19,9 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/usermem" ) const ( @@ -29,7 +29,7 @@ const ( // include/linux/kdev_t.h:MINORBITS devMinorBits = 20 - vsyscallEnd = usermem.Addr(0xffffffffff601000) + vsyscallEnd = hostarch.Addr(0xffffffffff601000) vsyscallMapsEntry = "ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0 [vsyscall]\n" vsyscallSmapsEntry = vsyscallMapsEntry + "Size: 4 kB\n" + @@ -62,7 +62,7 @@ func (mm *MemoryManager) NeedsUpdate(generation int64) bool { func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer) { mm.mappingMu.RLock() defer mm.mappingMu.RUnlock() - var start usermem.Addr + var start hostarch.Addr for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { mm.appendVMAMapsEntryLocked(ctx, vseg, buf) @@ -88,9 +88,9 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile mm.mappingMu.RLock() defer mm.mappingMu.RUnlock() var data []seqfile.SeqData - var start usermem.Addr + var start hostarch.Addr if handle != nil { - start = *handle.(*usermem.Addr) + start = *handle.(*hostarch.Addr) } for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { vmaAddr := vseg.End() @@ -177,7 +177,7 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffer) { mm.mappingMu.RLock() defer mm.mappingMu.RUnlock() - var start usermem.Addr + var start hostarch.Addr for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf) @@ -196,9 +196,9 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil mm.mappingMu.RLock() defer mm.mappingMu.RUnlock() var data []seqfile.SeqData - var start usermem.Addr + var start hostarch.Addr if handle != nil { - start = *handle.(*usermem.Addr) + start = *handle.(*hostarch.Addr) } for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { vmaAddr := vseg.End() @@ -279,8 +279,8 @@ func (mm *MemoryManager) vmaSmapsEntryIntoLocked(ctx context.Context, vseg vmaIt // Swap is not implemented. fmt.Fprintf(b, "Swap: %8d kB\n", 0) fmt.Fprintf(b, "SwapPss: %8d kB\n", 0) - fmt.Fprintf(b, "KernelPageSize: %8d kB\n", usermem.PageSize/1024) - fmt.Fprintf(b, "MMUPageSize: %8d kB\n", usermem.PageSize/1024) + fmt.Fprintf(b, "KernelPageSize: %8d kB\n", hostarch.PageSize/1024) + fmt.Fprintf(b, "MMUPageSize: %8d kB\n", hostarch.PageSize/1024) locked := rss if vma.mlockMode == memmap.MLockNone { locked = 0 diff --git a/pkg/sentry/mm/shm.go b/pkg/sentry/mm/shm.go index 6432731d4..3130be80c 100644 --- a/pkg/sentry/mm/shm.go +++ b/pkg/sentry/mm/shm.go @@ -16,13 +16,13 @@ package mm import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/shm" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // DetachShm unmaps a sysv shared memory segment. 
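The smaps hunk keeps emitting the KernelPageSize and MMUPageSize lines, only sourcing the page size from hostarch instead of usermem. A tiny sketch of that formatting, assuming the gvisor module is available and using the same format strings as the hunk:

package main

import (
	"bytes"
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

func main() {
	var b bytes.Buffer
	// Page sizes are reported in kB, as in vmaSmapsEntryIntoLocked.
	fmt.Fprintf(&b, "KernelPageSize: %8d kB\n", hostarch.PageSize/1024)
	fmt.Fprintf(&b, "MMUPageSize: %8d kB\n", hostarch.PageSize/1024)
	fmt.Print(b.String())
}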
-func (mm *MemoryManager) DetachShm(ctx context.Context, addr usermem.Addr) error { +func (mm *MemoryManager) DetachShm(ctx context.Context, addr hostarch.Addr) error { if addr != addr.RoundDown() { // "... shmaddr is not aligned on a page boundary." - man shmdt(2) return syserror.EINVAL @@ -52,7 +52,7 @@ func (mm *MemoryManager) DetachShm(ctx context.Context, addr usermem.Addr) error } // Remove all vmas that could have been created by the same attach. - end := addr + usermem.Addr(detached.EffectiveSize()) + end := addr + hostarch.Addr(detached.EffectiveSize()) for vseg.Ok() && vseg.End() <= end { vma := vseg.ValuePtr() if vma.mappable == detached && uint64(vseg.Start()-addr) == vma.off { diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 48d8b6a2b..e748b7ff8 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -16,11 +16,11 @@ package mm import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // SpecialMappable implements memmap.MappingIdentity and memmap.Mappable with @@ -77,21 +77,21 @@ func (m *SpecialMappable) Msync(ctx context.Context, mr memmap.MappableRange) er } // AddMapping implements memmap.Mappable.AddMapping. -func (*SpecialMappable) AddMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) error { +func (*SpecialMappable) AddMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, uint64, bool) error { return nil } // RemoveMapping implements memmap.Mappable.RemoveMapping. -func (*SpecialMappable) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) { +func (*SpecialMappable) RemoveMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, uint64, bool) { } // CopyMapping implements memmap.Mappable.CopyMapping. -func (*SpecialMappable) CopyMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, usermem.AddrRange, uint64, bool) error { +func (*SpecialMappable) CopyMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, hostarch.AddrRange, uint64, bool) error { return nil } // Translate implements memmap.Mappable.Translate. 
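DetachShm computes the end of the original attach as the detach address plus the segment's effective size, then removes vmas that fall within that span and whose mappable offset lines up with the attach. A standalone sketch of those two checks, assuming the gvisor module is available; sameAttach and its parameters are illustrative, with the segment size and vma bounds passed in directly:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

// sameAttach reports whether a vma spanning [vsegStart, vsegEnd) with mappable
// offset vmaOff could have been created by an attach of effectiveSize bytes at
// addr, mirroring the per-vma checks in DetachShm.
func sameAttach(addr, vsegStart, vsegEnd hostarch.Addr, vmaOff, effectiveSize uint64) bool {
	end := addr + hostarch.Addr(effectiveSize)
	return vsegEnd <= end && uint64(vsegStart-addr) == vmaOff
}

func main() {
	addr := hostarch.Addr(0x7f0000000000)
	fmt.Println(sameAttach(addr, addr+0x2000, addr+0x3000, 0x2000, 0x4000)) // true
	fmt.Println(sameAttach(addr, addr+0x2000, addr+0x3000, 0x1000, 0x4000)) // false: offset mismatch
}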
-func (m *SpecialMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (m *SpecialMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { var err error if required.End > m.fr.Length() { err = &memmap.BusError{syserror.EFAULT} @@ -102,7 +102,7 @@ func (m *SpecialMappable) Translate(ctx context.Context, required, optional memm Source: source, File: m.mfp.MemoryFile(), Offset: m.fr.Start + source.Start, - Perms: usermem.AnyAccess, + Perms: hostarch.AnyAccess, }, }, err } @@ -146,7 +146,7 @@ func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*Spec if length == 0 { return nil, syserror.EINVAL } - alignedLen, ok := usermem.Addr(length).RoundUp() + alignedLen, ok := hostarch.Addr(length).RoundUp() if !ok { return nil, syserror.EINVAL } diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 69e37330b..7ad6b7c21 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -21,20 +21,20 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // HandleUserFault handles an application page fault. sp is the faulting // application thread's stack pointer. // // Preconditions: mm.as != nil. -func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr usermem.Addr, at usermem.AccessType, sp usermem.Addr) error { - ar, ok := addr.RoundDown().ToRange(usermem.PageSize) +func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr hostarch.Addr, at hostarch.AccessType, sp hostarch.Addr) error { + ar, ok := addr.RoundDown().ToRange(hostarch.PageSize) if !ok { return syserror.EFAULT } @@ -72,11 +72,11 @@ func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr usermem.Addr, } // MMap establishes a memory mapping. -func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (usermem.Addr, error) { +func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostarch.Addr, error) { if opts.Length == 0 { return 0, syserror.EINVAL } - length, ok := usermem.Addr(opts.Length).RoundUp() + length, ok := hostarch.Addr(opts.Length).RoundUp() if !ok { return 0, syserror.ENOMEM } @@ -84,7 +84,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme if opts.Mappable != nil { // Offset must be aligned. - if usermem.Addr(opts.Offset).RoundDown() != usermem.Addr(opts.Offset) { + if hostarch.Addr(opts.Offset).RoundDown() != hostarch.Addr(opts.Offset) { return 0, syserror.EINVAL } // Offset + length must not overflow. @@ -157,7 +157,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme // Preconditions: // * mm.mappingMu must be locked. // * vseg.Range().IsSupersetOf(ar). -func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) { +func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, precommit bool) { if !vseg.ValuePtr().effectivePerms.Any() { // Linux doesn't populate inaccessible pages. See // mm/gup.c:populate_vma_page_range. 
@@ -175,7 +175,7 @@ func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar u } // Ensure that we have usable pmas. - pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, usermem.NoAccess) + pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, hostarch.NoAccess) if err != nil { // mm/util.c:vm_mmap_pgoff() ignores the error, if any, from // mm/gup.c:mm_populate(). If it matters, we'll get it again when @@ -203,7 +203,7 @@ func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar u // * vseg.Range().IsSupersetOf(ar). // // Postconditions: mm.mappingMu will be unlocked. -func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) { +func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, precommit bool) { // See populateVMA above for commentary. if !vseg.ValuePtr().effectivePerms.Any() { mm.mappingMu.Unlock() @@ -221,7 +221,7 @@ func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaItera // mm.mappingMu doesn't need to be write-locked for getPMAsLocked, and it // isn't needed at all for mapASLocked. mm.mappingMu.DowngradeLock() - pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, usermem.NoAccess) + pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, hostarch.NoAccess) mm.mappingMu.RUnlock() if err != nil { mm.activeMu.Unlock() @@ -234,7 +234,7 @@ func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaItera } // MapStack allocates the initial process stack. -func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error) { +func (mm *MemoryManager) MapStack(ctx context.Context) (hostarch.AddrRange, error) { // maxStackSize is the maximum supported process stack size in bytes. // // This limit exists because stack growing isn't implemented, so the entire @@ -242,7 +242,7 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error const maxStackSize = 128 << 20 stackSize := limits.FromContext(ctx).Get(limits.Stack) - r, ok := usermem.Addr(stackSize.Cur).RoundUp() + r, ok := hostarch.Addr(stackSize.Cur).RoundUp() sz := uint64(r) if !ok { // RLIM_INFINITY rounds up to 0. @@ -251,16 +251,16 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error ctx.Warningf("Capping stack size from RLIMIT_STACK of %v down to %v.", sz, maxStackSize) sz = maxStackSize } else if sz == 0 { - return usermem.AddrRange{}, syserror.ENOMEM + return hostarch.AddrRange{}, syserror.ENOMEM } - szaddr := usermem.Addr(sz) + szaddr := hostarch.Addr(sz) ctx.Debugf("Allocating stack with size of %v bytes", sz) // Determine the stack's desired location. Unlike Linux, address // randomization can't be disabled. 
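MapStack derives the stack size from RLIMIT_STACK: the soft limit is rounded up to a page, capped at the 128 MB maxStackSize (stack growing isn't implemented), and a zero size is rejected with ENOMEM. A sketch of just that sizing step, assuming the gvisor module is available; the stackSizeFromRlimit name is illustrative, errors are plain Go errors, and the RLIM_INFINITY branch elided in the hunk is not reproduced:

package main

import (
	"errors"
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

const maxStackSize = 128 << 20 // same cap as the hunk: 128 MB

// stackSizeFromRlimit mirrors the visible part of MapStack's sizing: round the
// RLIMIT_STACK soft limit up to a page, cap it, and reject a zero size.
func stackSizeFromRlimit(cur uint64) (uint64, error) {
	r, ok := hostarch.Addr(cur).RoundUp()
	if !ok {
		return 0, errors.New("limit rounds up past the end of the address space")
	}
	sz := uint64(r)
	if sz > maxStackSize {
		sz = maxStackSize
	} else if sz == 0 {
		return 0, errors.New("zero stack size") // syserror.ENOMEM in the real code
	}
	return sz, nil
}

func main() {
	sz, err := stackSizeFromRlimit(8 << 20) // a typical 8 MB soft limit
	fmt.Println(sz, err)
}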
- stackEnd := mm.layout.MaxAddr - usermem.Addr(mrand.Int63n(int64(mm.layout.MaxStackRand))).RoundDown() + stackEnd := mm.layout.MaxAddr - hostarch.Addr(mrand.Int63n(int64(mm.layout.MaxStackRand))).RoundDown() if stackEnd < szaddr { - return usermem.AddrRange{}, syserror.ENOMEM + return hostarch.AddrRange{}, syserror.ENOMEM } stackStart := stackEnd - szaddr mm.mappingMu.Lock() @@ -268,8 +268,8 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error _, ar, err := mm.createVMALocked(ctx, memmap.MMapOpts{ Length: sz, Addr: stackStart, - Perms: usermem.ReadWrite, - MaxPerms: usermem.AnyAccess, + Perms: hostarch.ReadWrite, + MaxPerms: hostarch.AnyAccess, Private: true, GrowsDown: true, MLockMode: mm.defMLockMode, @@ -279,14 +279,14 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error } // MUnmap implements the semantics of Linux's munmap(2). -func (mm *MemoryManager) MUnmap(ctx context.Context, addr usermem.Addr, length uint64) error { +func (mm *MemoryManager) MUnmap(ctx context.Context, addr hostarch.Addr, length uint64) error { if addr != addr.RoundDown() { return syserror.EINVAL } if length == 0 { return syserror.EINVAL } - la, ok := usermem.Addr(length).RoundUp() + la, ok := hostarch.Addr(length).RoundUp() if !ok { return syserror.EINVAL } @@ -308,7 +308,7 @@ type MRemapOpts struct { // NewAddr is the new address for the remapping. NewAddr is ignored unless // Move is MMRemapMustMove. - NewAddr usermem.Addr + NewAddr hostarch.Addr } // MRemapMoveMode controls MRemap's moving behavior. @@ -328,7 +328,7 @@ const ( ) // MRemap implements the semantics of Linux's mremap(2). -func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSize uint64, newSize uint64, opts MRemapOpts) (usermem.Addr, error) { +func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldSize uint64, newSize uint64, opts MRemapOpts) (hostarch.Addr, error) { // "Note that old_address has to be page aligned." - mremap(2) if oldAddr.RoundDown() != oldAddr { return 0, syserror.EINVAL @@ -336,9 +336,9 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi // Linux treats an old_size that rounds up to 0 as 0, which is otherwise a // valid size. However, new_size can't be 0 after rounding. - oldSizeAddr, _ := usermem.Addr(oldSize).RoundUp() + oldSizeAddr, _ := hostarch.Addr(oldSize).RoundUp() oldSize = uint64(oldSizeAddr) - newSizeAddr, ok := usermem.Addr(newSize).RoundUp() + newSizeAddr, ok := hostarch.Addr(newSize).RoundUp() if !ok || newSizeAddr == 0 { return 0, syserror.EINVAL } @@ -392,8 +392,8 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi if newSize < oldSize { // If oldAddr+oldSize didn't overflow, oldAddr+newSize can't // either. - newEnd := oldAddr + usermem.Addr(newSize) - mm.unmapLocked(ctx, usermem.AddrRange{newEnd, oldEnd}) + newEnd := oldAddr + hostarch.Addr(newSize) + mm.unmapLocked(ctx, hostarch.AddrRange{newEnd, oldEnd}) } return oldAddr, nil } @@ -438,7 +438,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi } // Find a location for the new mapping. 
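When mremap shrinks a mapping in place, the hunk above just unmaps the tail [oldAddr+newSize, oldEnd) and returns the original address unchanged. A standalone sketch of that range computation, assuming the gvisor module is available; shrinkTail is an illustrative helper and both sizes are assumed to already be page-rounded, as they are by the time the hunk runs:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

// shrinkTail returns the range MRemap unmaps when newSize < oldSize.
func shrinkTail(oldAddr hostarch.Addr, oldSize, newSize uint64) hostarch.AddrRange {
	oldEnd := oldAddr + hostarch.Addr(oldSize)
	newEnd := oldAddr + hostarch.Addr(newSize)
	return hostarch.AddrRange{newEnd, oldEnd}
}

func main() {
	ar := shrinkTail(0x400000, 0x4000, 0x1000)
	fmt.Printf("unmap tail: [%#x, %#x)\n", ar.Start, ar.End)
}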
- var newAR usermem.AddrRange + var newAR hostarch.AddrRange switch opts.Move { case MRemapMayMove: newAddr, err := mm.findAvailableLocked(newSize, findAvailableOpts{}) @@ -457,7 +457,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi if !ok { return 0, syserror.EINVAL } - if (usermem.AddrRange{oldAddr, oldEnd}).Overlaps(newAR) { + if (hostarch.AddrRange{oldAddr, oldEnd}).Overlaps(newAR) { return 0, syserror.EINVAL } @@ -479,8 +479,8 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi // correct: compare Linux's mm/mremap.c:mremap_to() => do_munmap(), // vma_to_resize(). if newSize < oldSize { - oldNewEnd := oldAddr + usermem.Addr(newSize) - mm.unmapLocked(ctx, usermem.AddrRange{oldNewEnd, oldEnd}) + oldNewEnd := oldAddr + hostarch.Addr(newSize) + mm.unmapLocked(ctx, hostarch.AddrRange{oldNewEnd, oldEnd}) oldEnd = oldNewEnd } @@ -488,7 +488,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi vseg = mm.vmas.FindSegment(oldAddr) } - oldAR := usermem.AddrRange{oldAddr, oldEnd} + oldAR := hostarch.AddrRange{oldAddr, oldEnd} // Check that oldEnd maps to the same vma as oldAddr. if vseg.End() < oldEnd { @@ -588,14 +588,14 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi } // MProtect implements the semantics of Linux's mprotect(2). -func (mm *MemoryManager) MProtect(addr usermem.Addr, length uint64, realPerms usermem.AccessType, growsDown bool) error { +func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms hostarch.AccessType, growsDown bool) error { if addr.RoundDown() != addr { return syserror.EINVAL } if length == 0 { return nil } - rlength, ok := usermem.Addr(length).RoundUp() + rlength, ok := hostarch.Addr(length).RoundUp() if !ok { return syserror.ENOMEM } @@ -692,19 +692,19 @@ func (mm *MemoryManager) MProtect(addr usermem.Addr, length uint64, realPerms us } // BrkSetup sets mm's brk address to addr and its brk size to 0. -func (mm *MemoryManager) BrkSetup(ctx context.Context, addr usermem.Addr) { +func (mm *MemoryManager) BrkSetup(ctx context.Context, addr hostarch.Addr) { mm.mappingMu.Lock() defer mm.mappingMu.Unlock() // Unmap the existing brk. if mm.brk.Length() != 0 { mm.unmapLocked(ctx, mm.brk) } - mm.brk = usermem.AddrRange{addr, addr} + mm.brk = hostarch.AddrRange{addr, addr} } // Brk implements the semantics of Linux's brk(2), except that it returns an // error on failure. -func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Addr, error) { +func (mm *MemoryManager) Brk(ctx context.Context, addr hostarch.Addr) (hostarch.Addr, error) { mm.mappingMu.Lock() // Can't defer mm.mappingMu.Unlock(); see below. @@ -741,8 +741,8 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Ad Fixed: true, // Compare Linux's // arch/x86/include/asm/page_types.h:VM_DATA_DEFAULT_FLAGS. - Perms: usermem.ReadWrite, - MaxPerms: usermem.AnyAccess, + Perms: hostarch.ReadWrite, + MaxPerms: hostarch.AnyAccess, Private: true, // Linux: mm/mmap.c:sys_brk() => do_brk_flags() includes // mm->def_flags. 
@@ -762,7 +762,7 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Ad } case newbrkpg < oldbrkpg: - mm.unmapLocked(ctx, usermem.AddrRange{newbrkpg, oldbrkpg}) + mm.unmapLocked(ctx, hostarch.AddrRange{newbrkpg, oldbrkpg}) fallthrough default: @@ -775,9 +775,9 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Ad // MLock implements the semantics of Linux's mlock()/mlock2()/munlock(), // depending on mode. -func (mm *MemoryManager) MLock(ctx context.Context, addr usermem.Addr, length uint64, mode memmap.MLockMode) error { +func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length uint64, mode memmap.MLockMode) error { // Linux allows this to overflow. - la, _ := usermem.Addr(length + addr.PageOffset()).RoundUp() + la, _ := hostarch.Addr(length + addr.PageOffset()).RoundUp() ar, ok := addr.RoundDown().ToRange(uint64(la)) if !ok { return syserror.EINVAL @@ -850,7 +850,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr usermem.Addr, length ui mm.mappingMu.RUnlock() return syserror.ENOMEM } - _, _, err := mm.getPMAsLocked(ctx, vseg, vseg.Range().Intersect(ar), usermem.NoAccess) + _, _, err := mm.getPMAsLocked(ctx, vseg, vseg.Range().Intersect(ar), hostarch.NoAccess) if err != nil { mm.activeMu.Unlock() mm.mappingMu.RUnlock() @@ -945,7 +945,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error mm.mappingMu.DowngradeLock() for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() { if vseg.ValuePtr().effectivePerms.Any() { - mm.getPMAsLocked(ctx, vseg, vseg.Range(), usermem.NoAccess) + mm.getPMAsLocked(ctx, vseg, vseg.Range(), hostarch.NoAccess) } } @@ -965,7 +965,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error } // NumaPolicy implements the semantics of Linux's get_mempolicy(MPOL_F_ADDR). -func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (linux.NumaPolicy, uint64, error) { +func (mm *MemoryManager) NumaPolicy(addr hostarch.Addr) (linux.NumaPolicy, uint64, error) { mm.mappingMu.RLock() defer mm.mappingMu.RUnlock() vseg := mm.vmas.FindSegment(addr) @@ -977,12 +977,12 @@ func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (linux.NumaPolicy, uint64 } // SetNumaPolicy implements the semantics of Linux's mbind(). -func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error { +func (mm *MemoryManager) SetNumaPolicy(addr hostarch.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error { if !addr.IsPageAligned() { return syserror.EINVAL } // Linux allows this to overflow. - la, _ := usermem.Addr(length).RoundUp() + la, _ := hostarch.Addr(length).RoundUp() ar, ok := addr.ToRange(uint64(la)) if !ok { return syserror.EINVAL @@ -1018,7 +1018,7 @@ func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy } // SetDontFork implements the semantics of madvise MADV_DONTFORK. -func (mm *MemoryManager) SetDontFork(addr usermem.Addr, length uint64, dontfork bool) error { +func (mm *MemoryManager) SetDontFork(addr hostarch.Addr, length uint64, dontfork bool) error { ar, ok := addr.ToRange(length) if !ok { return syserror.EINVAL @@ -1044,7 +1044,7 @@ func (mm *MemoryManager) SetDontFork(addr usermem.Addr, length uint64, dontfork } // Decommit implements the semantics of Linux's madvise(MADV_DONTNEED). 
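MLock accepts an unaligned (addr, length) pair and widens it to whole pages: the length is extended by the address's offset within its page, rounded up, and then anchored at the rounded-down address. A sketch of that widening, assuming the gvisor module is available; as the comment in the hunk notes, Linux deliberately lets the rounded length overflow, and ToRange then reports the failure:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

// mlockRange widens (addr, length) to a page-aligned range the way MLock does.
func mlockRange(addr hostarch.Addr, length uint64) (hostarch.AddrRange, bool) {
	la, _ := hostarch.Addr(length + addr.PageOffset()).RoundUp()
	return addr.RoundDown().ToRange(uint64(la))
}

func main() {
	ar, ok := mlockRange(0x40_1234, 100)
	fmt.Printf("locked range: %v (ok=%v)\n", ar, ok) // one whole page around the request
}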
-func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error { +func (mm *MemoryManager) Decommit(addr hostarch.Addr, length uint64) error { ar, ok := addr.ToRange(length) if !ok { return syserror.EINVAL @@ -1112,14 +1112,14 @@ type MSyncOpts struct { } // MSync implements the semantics of Linux's msync(). -func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length uint64, opts MSyncOpts) error { +func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length uint64, opts MSyncOpts) error { if addr != addr.RoundDown() { return syserror.EINVAL } if length == 0 { return nil } - la, ok := usermem.Addr(length).RoundUp() + la, ok := hostarch.Addr(length).RoundUp() if !ok { return syserror.ENOMEM } @@ -1188,7 +1188,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length ui } // GetSharedFutexKey is used by kernel.Task.GetSharedKey. -func (mm *MemoryManager) GetSharedFutexKey(ctx context.Context, addr usermem.Addr) (futex.Key, error) { +func (mm *MemoryManager) GetSharedFutexKey(ctx context.Context, addr hostarch.Addr) (futex.Key, error) { ar, ok := addr.ToRange(4) // sizeof(int32). if !ok { return futex.Key{}, syserror.EFAULT @@ -1196,7 +1196,7 @@ func (mm *MemoryManager) GetSharedFutexKey(ctx context.Context, addr usermem.Add mm.mappingMu.RLock() defer mm.mappingMu.RUnlock() - vseg, _, err := mm.getVMAsLocked(ctx, ar, usermem.Read, false) + vseg, _, err := mm.getVMAsLocked(ctx, ar, hostarch.Read, false) if err != nil { return futex.Key{}, err } @@ -1230,7 +1230,7 @@ func (mm *MemoryManager) VirtualMemorySize() uint64 { // VirtualMemorySizeRange returns the combined length in bytes of all mappings // in ar in mm. -func (mm *MemoryManager) VirtualMemorySizeRange(ar usermem.AddrRange) uint64 { +func (mm *MemoryManager) VirtualMemorySizeRange(ar hostarch.AddrRange) uint64 { mm.mappingMu.RLock() defer mm.mappingMu.RUnlock() return uint64(mm.vmas.SpanRange(ar)) diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index b8df72813..0d019e41d 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -19,18 +19,18 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Preconditions: // * mm.mappingMu must be locked for writing. // * opts must be valid as defined by the checks in MMap. 
-func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOpts) (vmaIterator, usermem.AddrRange, error) { +func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOpts) (vmaIterator, hostarch.AddrRange, error) { if opts.MaxPerms != opts.MaxPerms.Effective() { panic(fmt.Sprintf("Non-effective MaxPerms %s cannot be enforced", opts.MaxPerms)) } @@ -47,7 +47,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp if opts.Force && opts.Unmap && opts.Fixed { addr = opts.Addr } else { - return vmaIterator{}, usermem.AddrRange{}, err + return vmaIterator{}, hostarch.AddrRange{}, err } } ar, _ := addr.ToRange(opts.Length) @@ -58,7 +58,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp newUsageAS -= uint64(mm.vmas.SpanRange(ar)) } if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS { - return vmaIterator{}, usermem.AddrRange{}, syserror.ENOMEM + return vmaIterator{}, hostarch.AddrRange{}, syserror.ENOMEM } if opts.MLockMode != memmap.MLockNone { @@ -66,14 +66,14 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) { mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur if mlockLimit == 0 { - return vmaIterator{}, usermem.AddrRange{}, syserror.EPERM + return vmaIterator{}, hostarch.AddrRange{}, syserror.EPERM } newLockedAS := mm.lockedAS + opts.Length if opts.Unmap { newLockedAS -= mm.mlockedBytesRangeLocked(ar) } if newLockedAS > mlockLimit { - return vmaIterator{}, usermem.AddrRange{}, syserror.EAGAIN + return vmaIterator{}, hostarch.AddrRange{}, syserror.EAGAIN } } } @@ -93,7 +93,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp // The expression for writable is vma.canWriteMappableLocked(), but we // don't yet have a vma. if err := opts.Mappable.AddMapping(ctx, mm, ar, opts.Offset, !opts.Private && opts.MaxPerms.Write); err != nil { - return vmaIterator{}, usermem.AddrRange{}, err + return vmaIterator{}, hostarch.AddrRange{}, err } } @@ -137,7 +137,7 @@ type findAvailableOpts struct { // // - Unmap allows existing guard pages in the returned range. - Addr usermem.Addr + Addr hostarch.Addr Fixed bool Unmap bool Map32Bit bool @@ -153,13 +153,13 @@ const ( // findAvailableLocked finds an allocatable range. // // Preconditions: mm.mappingMu must be locked. -func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOpts) (usermem.Addr, error) { +func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOpts) (hostarch.Addr, error) { if opts.Fixed { opts.Map32Bit = false } allowedAR := mm.applicationAddrRange() if opts.Map32Bit { - allowedAR = allowedAR.Intersect(usermem.AddrRange{map32Start, map32End}) + allowedAR = allowedAR.Intersect(hostarch.AddrRange{map32Start, map32End}) } // Does the provided suggestion work? @@ -181,33 +181,33 @@ func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOp } // Prefer hugepage alignment if a hugepage or more is requested. 
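createVMALocked charges the new mapping against RLIMIT_AS before creating anything: the prospective usage is the current mapped span plus the new length, minus whatever existing vmas the mapping will replace when Unmap is set. A standalone sketch of that accounting; the helper name and parameter list are illustrative:

package main

import "fmt"

// exceedsAddressSpaceLimit mirrors the RLIMIT_AS check in createVMALocked:
// usageAS is the current mapped span, length the new mapping's length,
// replaced the span of existing vmas the mapping will unmap (0 unless
// MMapOpts.Unmap is set), and limitAS the RLIMIT_AS soft limit.
func exceedsAddressSpaceLimit(usageAS, length, replaced, limitAS uint64) bool {
	newUsageAS := usageAS + length
	newUsageAS -= replaced
	return newUsageAS > limitAS
}

func main() {
	// A 1 MB mapping over 256 KB of existing vmas, with 2 MB already mapped
	// and a 4 MB limit, fits comfortably.
	fmt.Println(exceedsAddressSpaceLimit(2<<20, 1<<20, 256<<10, 4<<20)) // false
}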
- alignment := uint64(usermem.PageSize) - if length >= usermem.HugePageSize { - alignment = usermem.HugePageSize + alignment := uint64(hostarch.PageSize) + if length >= hostarch.HugePageSize { + alignment = hostarch.HugePageSize } if opts.Map32Bit { return mm.findLowestAvailableLocked(length, alignment, allowedAR) } if mm.layout.DefaultDirection == arch.MmapBottomUp { - return mm.findLowestAvailableLocked(length, alignment, usermem.AddrRange{mm.layout.BottomUpBase, mm.layout.MaxAddr}) + return mm.findLowestAvailableLocked(length, alignment, hostarch.AddrRange{mm.layout.BottomUpBase, mm.layout.MaxAddr}) } - return mm.findHighestAvailableLocked(length, alignment, usermem.AddrRange{mm.layout.MinAddr, mm.layout.TopDownBase}) + return mm.findHighestAvailableLocked(length, alignment, hostarch.AddrRange{mm.layout.MinAddr, mm.layout.TopDownBase}) } -func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange { - return usermem.AddrRange{mm.layout.MinAddr, mm.layout.MaxAddr} +func (mm *MemoryManager) applicationAddrRange() hostarch.AddrRange { + return hostarch.AddrRange{mm.layout.MinAddr, mm.layout.MaxAddr} } // Preconditions: mm.mappingMu must be locked. -func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) { - for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) { +func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds hostarch.AddrRange) (hostarch.Addr, error) { + for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(hostarch.Addr(length)) { if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length { // Can we shift up to match the alignment? if offset := uint64(gr.Start) % alignment; offset != 0 { if uint64(gr.Length()) >= length+alignment-offset { // Yes, we're aligned. - return gr.Start + usermem.Addr(alignment-offset), nil + return gr.Start + hostarch.Addr(alignment-offset), nil } } @@ -219,15 +219,15 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou } // Preconditions: mm.mappingMu must be locked. -func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) { - for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) { +func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds hostarch.AddrRange) (hostarch.Addr, error) { + for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(hostarch.Addr(length)) { if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length { // Can we shift down to match the alignment? - start := gr.End - usermem.Addr(length) + start := gr.End - hostarch.Addr(length) if offset := uint64(start) % alignment; offset != 0 { - if gr.Start <= start-usermem.Addr(offset) { + if gr.Start <= start-hostarch.Addr(offset) { // Yes, we're aligned. - return start - usermem.Addr(offset), nil + return start - hostarch.Addr(offset), nil } } @@ -239,7 +239,7 @@ func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bo } // Preconditions: mm.mappingMu must be locked. 
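findLowestAvailableLocked walks gaps bottom-up and, when a gap's start is not aligned to the requested unit, shifts the candidate address up to the next aligned boundary, provided the gap is still large enough after the shift. A standalone sketch of that per-gap decision, assuming the gvisor module is available; the gap iteration is omitted, alignWithinGap is an illustrative name, and the fallback the real code takes when the shift does not fit is elided in the hunk, so this sketch simply reports failure:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

// alignWithinGap tries to place an allocation of the given length at the
// lowest alignment-aligned address inside the gap [grStart, grStart+grLen).
func alignWithinGap(grStart hostarch.Addr, grLen, length, alignment uint64) (hostarch.Addr, bool) {
	if grLen < length {
		return 0, false
	}
	offset := uint64(grStart) % alignment
	if offset == 0 {
		return grStart, true
	}
	if grLen >= length+alignment-offset {
		// Shift up to the next aligned boundary; the request still fits.
		return grStart + hostarch.Addr(alignment-offset), true
	}
	return 0, false
}

func main() {
	// Ask for 2 MB, huge-page aligned, inside an 8 MB gap that starts mid-huge-page.
	addr, ok := alignWithinGap(0x20_1000, 8<<20, 2<<20, uint64(hostarch.HugePageSize))
	fmt.Printf("placed at %#x (ok=%v)\n", addr, ok)
}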
-func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 { +func (mm *MemoryManager) mlockedBytesRangeLocked(ar hostarch.AddrRange) uint64 { var total uint64 for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() { if vseg.ValuePtr().mlockMode != memmap.MLockNone { @@ -264,7 +264,7 @@ func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 { // Preconditions: // * mm.mappingMu must be locked for reading; it may be temporarily unlocked. // * ar.Length() != 0. -func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) { +func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar hostarch.AddrRange, at hostarch.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -320,7 +320,7 @@ func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange // temporarily unlocked. // // Postconditions: ars is not mutated. -func (mm *MemoryManager) getVecVMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool) (usermem.AddrRangeSeq, error) { +func (mm *MemoryManager) getVecVMAsLocked(ctx context.Context, ars hostarch.AddrRangeSeq, at hostarch.AccessType, ignorePermissions bool) (hostarch.AddrRangeSeq, error) { for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() { ar := arsit.Head() if ar.Length() == 0 { @@ -339,7 +339,7 @@ func (mm *MemoryManager) getVecVMAsLocked(ctx context.Context, ars usermem.AddrR // // guardBytes is equivalent to Linux's stack_guard_gap after upstream // 1be7107fbe18 "mm: larger stack guard gap, between vmas". -const guardBytes = 256 * usermem.PageSize +const guardBytes = 256 * hostarch.PageSize // unmapLocked unmaps all addresses in ar and returns the resulting gap in // mm.vmas. @@ -348,7 +348,7 @@ const guardBytes = 256 * usermem.PageSize // * mm.mappingMu must be locked for writing. // * ar.Length() != 0. // * ar must be page-aligned. -func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator { +func (mm *MemoryManager) unmapLocked(ctx context.Context, ar hostarch.AddrRange) vmaGapIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -369,7 +369,7 @@ func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) // * mm.mappingMu must be locked for writing. // * ar.Length() != 0. // * ar must be page-aligned. -func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator { +func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar hostarch.AddrRange) vmaGapIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) @@ -426,12 +426,12 @@ func (vma *vma) isPrivateDataLocked() bool { // vmaSetFunctions implements segment.Functions for vmaSet. 
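guardBytes is now expressed in terms of hostarch.PageSize but remains Linux's 256-page stack guard gap: a gap that sits directly below a growsDown vma loses its top 256 pages, and a gap smaller than that is unusable entirely, as the availableRange hunk later in this diff implements. A sketch of that deduction, assuming the gvisor module is available; usableRange is an illustrative name for logic the real code keeps on the gap iterator:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/hostarch"
)

// guardBytes matches mm/vma.go: a 256-page guard gap below growsDown vmas.
const guardBytes = 256 * hostarch.PageSize

// usableRange trims the guard gap off a gap whose next vma grows down,
// mirroring vmaGapIterator.availableRange.
func usableRange(ar hostarch.AddrRange, nextGrowsDown bool) hostarch.AddrRange {
	if !nextGrowsDown {
		return ar
	}
	if ar.Length() < guardBytes {
		// The whole gap is guard pages.
		return hostarch.AddrRange{ar.Start, ar.Start}
	}
	ar.End -= guardBytes
	return ar
}

func main() {
	gap := hostarch.AddrRange{0x7f0000000000, 0x7f0000200000} // a 2 MB gap
	fmt.Printf("usable below a stack: %v\n", usableRange(gap, true))
}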
type vmaSetFunctions struct{} -func (vmaSetFunctions) MinKey() usermem.Addr { +func (vmaSetFunctions) MinKey() hostarch.Addr { return 0 } -func (vmaSetFunctions) MaxKey() usermem.Addr { - return ^usermem.Addr(0) +func (vmaSetFunctions) MaxKey() hostarch.Addr { + return ^hostarch.Addr(0) } func (vmaSetFunctions) ClearValue(vma *vma) { @@ -440,7 +440,7 @@ func (vmaSetFunctions) ClearValue(vma *vma) { vma.hint = "" } -func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRange, vma2 vma) (vma, bool) { +func (vmaSetFunctions) Merge(ar1 hostarch.AddrRange, vma1 vma, ar2 hostarch.AddrRange, vma2 vma) (vma, bool) { if vma1.mappable != vma2.mappable || (vma1.mappable != nil && vma1.off+uint64(ar1.Length()) != vma2.off) || vma1.realPerms != vma2.realPerms || @@ -462,7 +462,7 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa return vma1, true } -func (vmaSetFunctions) Split(ar usermem.AddrRange, v vma, split usermem.Addr) (vma, vma) { +func (vmaSetFunctions) Split(ar hostarch.AddrRange, v vma, split hostarch.Addr) (vma, vma) { v2 := v if v2.mappable != nil { v2.off += uint64(split - ar.Start) @@ -476,7 +476,7 @@ func (vmaSetFunctions) Split(ar usermem.AddrRange, v vma, split usermem.Addr) (v // Preconditions: // * vseg.ValuePtr().mappable != nil. // * vseg.Range().Contains(addr). -func (vseg vmaIterator) mappableOffsetAt(addr usermem.Addr) uint64 { +func (vseg vmaIterator) mappableOffsetAt(addr hostarch.Addr) uint64 { if checkInvariants { if !vseg.Ok() { panic("terminal vma iterator") @@ -503,7 +503,7 @@ func (vseg vmaIterator) mappableRange() memmap.MappableRange { // * vseg.ValuePtr().mappable != nil. // * vseg.Range().IsSupersetOf(ar). // * ar.Length() != 0. -func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRange { +func (vseg vmaIterator) mappableRangeOf(ar hostarch.AddrRange) memmap.MappableRange { if checkInvariants { if !vseg.Ok() { panic("terminal vma iterator") @@ -528,7 +528,7 @@ func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRan // * vseg.ValuePtr().mappable != nil. // * vseg.mappableRange().IsSupersetOf(mr). // * mr.Length() != 0. -func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange { +func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) hostarch.AddrRange { if checkInvariants { if !vseg.Ok() { panic("terminal vma iterator") @@ -546,7 +546,7 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange { vma := vseg.ValuePtr() vstart := vseg.Start() - return usermem.AddrRange{vstart + usermem.Addr(mr.Start-vma.off), vstart + usermem.Addr(mr.End-vma.off)} + return hostarch.AddrRange{vstart + hostarch.Addr(mr.Start-vma.off), vstart + hostarch.Addr(mr.End-vma.off)} } // seekNextLowerBound returns mm.vmas.LowerBoundSegment(addr), but does so by @@ -555,7 +555,7 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange { // Preconditions: // * mm.mappingMu must be locked. // * addr >= vseg.Start(). -func (vseg vmaIterator) seekNextLowerBound(addr usermem.Addr) vmaIterator { +func (vseg vmaIterator) seekNextLowerBound(addr hostarch.Addr) vmaIterator { if checkInvariants { if !vseg.Ok() { panic("terminal vma iterator") @@ -572,7 +572,7 @@ func (vseg vmaIterator) seekNextLowerBound(addr usermem.Addr) vmaIterator { // availableRange returns the subset of vgap.Range() in which new vmas may be // created without MMapOpts.Unmap == true. 
-func (vgap vmaGapIterator) availableRange() usermem.AddrRange { +func (vgap vmaGapIterator) availableRange() hostarch.AddrRange { ar := vgap.Range() next := vgap.NextSegment() if !next.Ok() || !next.ValuePtr().growsDown { @@ -580,7 +580,7 @@ func (vgap vmaGapIterator) availableRange() usermem.AddrRange { } // Exclude guard pages. if ar.Length() < guardBytes { - return usermem.AddrRange{ar.Start, ar.Start} + return hostarch.AddrRange{ar.Start, ar.Start} } ar.End -= guardBytes return ar diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index e5bf13c40..57d73d770 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -85,6 +85,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/memutil", "//pkg/safemem", @@ -106,5 +107,5 @@ go_test( size = "small", srcs = ["pgalloc_test.go"], library = ":pgalloc", - deps = ["//pkg/usermem"], + deps = ["//pkg/hostarch"], ) diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index a4af3e21b..b81292c46 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -31,6 +31,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostmm" @@ -38,7 +39,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // MemoryFile is a memmap.File whose pages may be allocated to arbitrary @@ -283,7 +283,7 @@ const ( chunkMask = chunkSize - 1 // maxPage is the highest 64-bit page. - maxPage = math.MaxUint64 &^ (usermem.PageSize - 1) + maxPage = math.MaxUint64 &^ (hostarch.PageSize - 1) ) // NewMemoryFile creates a MemoryFile backed by the given file. If @@ -344,7 +344,7 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) { m, _, errno := unix.Syscall6( unix.SYS_MMAP, 0, - usermem.PageSize, + hostarch.PageSize, unix.PROT_EXEC, unix.MAP_SHARED, file.Fd(), @@ -357,7 +357,7 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) { if _, _, errno := unix.Syscall( unix.SYS_MUNMAP, m, - usermem.PageSize, + hostarch.PageSize, 0); errno != 0 { panic(fmt.Sprintf("failed to unmap PROT_EXEC MemoryFile mapping: %v", errno)) } @@ -386,7 +386,7 @@ func (f *MemoryFile) Destroy() { // // Preconditions: length must be page-aligned and non-zero. func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) { - if length == 0 || length%usermem.PageSize != 0 { + if length == 0 || length%hostarch.PageSize != 0 { panic(fmt.Sprintf("invalid allocation length: %#x", length)) } @@ -395,9 +395,9 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File // Align hugepage-and-larger allocations on hugepage boundaries to try // to take advantage of hugetmpfs. - alignment := uint64(usermem.PageSize) - if length >= usermem.HugePageSize { - alignment = usermem.HugePageSize + alignment := uint64(hostarch.PageSize) + if length >= hostarch.HugePageSize { + alignment = hostarch.HugePageSize } // Find a range in the underlying file. 
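The Allocate hunk above rounds requests to whole pages and, once a request reaches huge-page size, aligns the search on huge-page boundaries so hugetmpfs-backed memory files can benefit from huge mappings. A small standalone sketch of that sizing rule using the hostarch constants introduced by this change (pageRound and allocAlignment are illustrative helpers, not code from this diff):

    package main

    import (
        "fmt"

        "gvisor.dev/gvisor/pkg/hostarch"
    )

    // pageRound pads an arbitrary byte count up to the page-aligned,
    // non-zero length that MemoryFile.Allocate requires.
    func pageRound(n uint64) uint64 {
        return (n + hostarch.PageSize - 1) &^ uint64(hostarch.PageSize-1)
    }

    // allocAlignment mirrors the alignment choice in Allocate: page
    // alignment by default, huge-page alignment for huge-page-sized
    // (or larger) requests.
    func allocAlignment(length uint64) uint64 {
        if length >= hostarch.HugePageSize {
            return hostarch.HugePageSize
        }
        return hostarch.PageSize
    }

    func main() {
        n := pageRound(5000) // padded to whole pages before allocation
        fmt.Printf("length=%d alignment=%d\n", n, allocAlignment(n))
    }
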
@@ -524,13 +524,13 @@ func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r saf if err != nil { return memmap.FileRange{}, err } - dsts, err := f.MapInternal(fr, usermem.Write) + dsts, err := f.MapInternal(fr, hostarch.Write) if err != nil { f.DecRef(fr) return memmap.FileRange{}, err } n, err := safemem.ReadFullToBlocks(r, dsts) - un := uint64(usermem.Addr(n).RoundDown()) + un := uint64(hostarch.Addr(n).RoundDown()) if un < length { // Free unused memory and update fr to contain only the memory that is // still allocated. @@ -552,7 +552,7 @@ const ( // // Preconditions: fr.Length() > 0. func (f *MemoryFile) Decommit(fr memmap.FileRange) error { - if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { + if !fr.WellFormed() || fr.Length() == 0 || fr.Start%hostarch.PageSize != 0 || fr.End%hostarch.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -614,7 +614,7 @@ func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { // IncRef implements memmap.File.IncRef. func (f *MemoryFile) IncRef(fr memmap.FileRange) { - if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { + if !fr.WellFormed() || fr.Length() == 0 || fr.Start%hostarch.PageSize != 0 || fr.End%hostarch.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -633,7 +633,7 @@ func (f *MemoryFile) IncRef(fr memmap.FileRange) { // DecRef implements memmap.File.DecRef. func (f *MemoryFile) DecRef(fr memmap.FileRange) { - if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { + if !fr.WellFormed() || fr.Length() == 0 || fr.Start%hostarch.PageSize != 0 || fr.End%hostarch.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -669,7 +669,7 @@ func (f *MemoryFile) DecRef(fr memmap.FileRange) { } // MapInternal implements memmap.File.MapInternal. -func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (f *MemoryFile) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) { if !fr.WellFormed() || fr.Length() == 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -935,7 +935,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( // Ensure that we have sufficient buffer for the call // (one byte per page). The length of each slice must // be page-aligned. - bufLen := len(s) / usermem.PageSize + bufLen := len(s) / hostarch.PageSize if len(buf) < bufLen { buf = make([]byte, bufLen) } @@ -967,8 +967,8 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( } } committedFR := memmap.FileRange{ - Start: r.Start + uint64(i*usermem.PageSize), - End: r.Start + uint64(j*usermem.PageSize), + Start: r.Start + uint64(i*hostarch.PageSize), + End: r.Start + uint64(j*hostarch.PageSize), } // Advance seg to committedFR.Start. 
for seg.Ok() && seg.End() < committedFR.Start { diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go index 405db141f..8d2b7eb5e 100644 --- a/pkg/sentry/pgalloc/pgalloc_test.go +++ b/pkg/sentry/pgalloc/pgalloc_test.go @@ -17,12 +17,12 @@ package pgalloc import ( "testing" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) const ( - page = usermem.PageSize - hugepage = usermem.HugePageSize + page = hostarch.PageSize + hugepage = hostarch.HugePageSize topPage = (1 << 63) - page ) diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index e05c8d074..345cdde55 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -23,11 +23,11 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/state" "gvisor.dev/gvisor/pkg/state/wire" - "gvisor.dev/gvisor/pkg/usermem" ) // SaveTo writes f's state to the given stream. @@ -49,11 +49,11 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error { // Ensure that all pages that contain data have knownCommitted set, since // we only store knownCommitted pages below. - zeroPage := make([]byte, usermem.PageSize) + zeroPage := make([]byte, hostarch.PageSize) err := f.updateUsageLocked(0, func(bs []byte, committed []byte) error { - for pgoff := 0; pgoff < len(bs); pgoff += usermem.PageSize { - i := pgoff / usermem.PageSize - pg := bs[pgoff : pgoff+usermem.PageSize] + for pgoff := 0; pgoff < len(bs); pgoff += hostarch.PageSize { + i := pgoff / hostarch.PageSize + pg := bs[pgoff : pgoff+hostarch.PageSize] if !bytes.Equal(pg, zeroPage) { committed[i] = 1 continue diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index db7d55ef2..7125657b3 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/hostmm", diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 03a76eb9b..b307832fd 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -43,6 +43,7 @@ go_library( "//pkg/atomicbitops", "//pkg/context", "//pkg/cpuid", + "//pkg/hostarch", "//pkg/log", "//pkg/procid", "//pkg/ring0", @@ -56,7 +57,6 @@ go_library( "//pkg/sentry/platform/interrupt", "//pkg/sentry/time", "//pkg/sync", - "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], ) @@ -65,6 +65,7 @@ go_test( name = "kvm_test", srcs = [ "kvm_amd64_test.go", + "kvm_amd64_test.s", "kvm_arm64_test.go", "kvm_test.go", "virtual_map_test.go", @@ -76,6 +77,7 @@ go_test( "requires-kvm", ], deps = [ + "//pkg/hostarch", "//pkg/ring0", "//pkg/ring0/pagetables", "//pkg/sentry/arch", @@ -83,7 +85,6 @@ go_test( "//pkg/sentry/platform", "//pkg/sentry/platform/kvm/testutil", "//pkg/sentry/time", - "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index 25c21e843..5524e8727 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -18,11 +18,11 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" - 
"gvisor.dev/gvisor/pkg/usermem" ) // dirtySet tracks vCPUs for invalidation. @@ -118,7 +118,7 @@ type hostMapEntry struct { // +checkescape:hard,stack // //go:nosplit -func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) { +func (as *addressSpace) mapLocked(addr hostarch.Addr, m hostMapEntry, at hostarch.AccessType) (inv bool) { for m.length > 0 { physical, length, ok := translateToPhysical(m.addr) if !ok { @@ -144,14 +144,14 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem. }, physical) || inv m.addr += length m.length -= length - addr += usermem.Addr(length) + addr += hostarch.Addr(length) } return inv } // MapFile implements platform.AddressSpace.MapFile. -func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { +func (as *addressSpace) MapFile(addr hostarch.Addr, f memmap.File, fr memmap.FileRange, at hostarch.AccessType, precommit bool) error { as.mu.Lock() defer as.mu.Unlock() @@ -165,7 +165,7 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.File // We don't execute from application file-mapped memory, and guest page // tables don't care if we have execute permission (but they do need pages // to be readable). - bs, err := f.MapInternal(fr, usermem.AccessType{ + bs, err := f.MapInternal(fr, hostarch.AccessType{ Read: at.Read || at.Execute || precommit, Write: at.Write, }) @@ -187,7 +187,7 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.File // lookup in our host page tables for this translation. if precommit { s := b.ToSlice() - for i := 0; i < len(s); i += usermem.PageSize { + for i := 0; i < len(s); i += hostarch.PageSize { _ = s[i] // Touch to commit. } } @@ -201,7 +201,7 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.File length: uintptr(b.Len()), }, at) inv = inv || prev - addr += usermem.Addr(b.Len()) + addr += hostarch.Addr(b.Len()) } if inv { as.invalidate() @@ -215,12 +215,12 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.File // +checkescape:hard,stack // //go:nosplit -func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool { +func (as *addressSpace) unmapLocked(addr hostarch.Addr, length uint64) bool { return as.pageTables.Unmap(addr, uintptr(length)) } // Unmap unmaps the given range by calling pagetables.PageTables.Unmap. -func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) { +func (as *addressSpace) Unmap(addr hostarch.Addr, length uint64) { as.mu.Lock() defer as.mu.Unlock() diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index fd1131638..bb9967b9f 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -16,7 +16,6 @@ package kvm import ( "fmt" - "reflect" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/ring0" @@ -36,6 +35,14 @@ func sighandler() // dieArchSetup and the assembly implementation for dieTrampoline. func dieTrampoline() +// Return the start address of the functions above. +// +// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal +// wrapper function rather than the function itself. We must reference from +// assembly to get the ABI0 (i.e., primary) address. +func addrOfSighandler() uintptr +func addrOfDieTrampoline() uintptr + var ( // bounceSignal is the signal used for bouncing KVM. 
// @@ -87,10 +94,10 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { func init() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { + if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } // Extract the address for the trampoline. - dieTrampolineAddr = reflect.ValueOf(dieTrampoline).Pointer() + dieTrampolineAddr = addrOfDieTrampoline() } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index 025ea93b5..953024600 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -81,8 +81,20 @@ fallback: MOVQ ·savedHandler(SB), AX JMP AX +// func addrOfSighandler() uintptr +TEXT ·addrOfSighandler(SB), $0-8 + MOVQ $·sighandler(SB), AX + MOVQ AX, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 PUSHQ BX // First argument (vCPU). PUSHQ AX // Fake the old RIP as caller. JMP ·dieHandler(SB) + +// func addrOfDieTrampoline() uintptr +TEXT ·addrOfDieTrampoline(SB), $0-8 + MOVQ $·dieTrampoline(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 09c7e88e5..308f2a951 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -92,6 +92,12 @@ fallback: MOVD ·savedHandler(SB), R7 B (R7) +// func addrOfSighandler() uintptr +TEXT ·addrOfSighandler(SB), $0-8 + MOVD $·sighandler(SB), R0 + MOVD R0, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. 
TEXT ·dieTrampoline(SB),NOSPLIT,$0 // R0: Fake the old PC as caller @@ -99,3 +105,9 @@ TEXT ·dieTrampoline(SB),NOSPLIT,$0 MOVD.P R1, 8(RSP) // R1: First argument (vCPU) MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller B ·dieHandler(SB) + +// func addrOfDieTrampoline() uintptr +TEXT ·addrOfDieTrampoline(SB), $0-8 + MOVD $·dieTrampoline(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index 37c53fa02..28a613a54 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -18,7 +18,7 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -47,7 +47,7 @@ func yield() { // //go:nosplit func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) { - alignedPhysical := physical &^ uintptr(usermem.PageSize-1) + alignedPhysical := physical &^ uintptr(hostarch.PageSize-1) for _, pr := range phyRegions { end := pr.physical + pr.length if physical < pr.physical || physical >= end { diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go index 706fa53dc..f4d4473a8 100644 --- a/pkg/sentry/platform/kvm/context.go +++ b/pkg/sentry/platform/kvm/context.go @@ -18,11 +18,11 @@ import ( "sync/atomic" pkgcontext "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" - "gvisor.dev/gvisor/pkg/usermem" ) // context is an implementation of the platform context. @@ -40,7 +40,7 @@ type context struct { } // Switch runs the provided context in the given address space. -func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) { +func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, hostarch.AccessType, error) { as := mm.AddressSpace() localAS := as.(*addressSpace) @@ -50,7 +50,7 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a // Enable interrupts (i.e. calls to vCPU.Notify). if !c.interrupt.Enable(cpu) { c.machine.Put(cpu) // Already preempted. - return nil, usermem.NoAccess, platform.ErrContextInterrupt + return nil, hostarch.NoAccess, platform.ErrContextInterrupt } // Set the active address space. diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index 92c05a9ad..aac0fdffe 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -20,11 +20,11 @@ import ( "os" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) // userMemoryRegion is a region of physical memory. @@ -146,13 +146,13 @@ func (*KVM) MapUnit() uint64 { } // MinUserAddress returns the lowest available address. -func (*KVM) MinUserAddress() usermem.Addr { - return usermem.PageSize +func (*KVM) MinUserAddress() hostarch.Addr { + return hostarch.PageSize } // MaxUserAddress returns the first address that may not be used. 
-func (*KVM) MaxUserAddress() usermem.Addr { - return usermem.Addr(ring0.MaximumUserAddress) +func (*KVM) MaxUserAddress() hostarch.Addr { + return hostarch.Addr(ring0.MaximumUserAddress) } // NewAddressSpace returns a new pagetable root. diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go index e44e995a0..b8dd1e4a5 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -49,3 +49,40 @@ func TestSegments(t *testing.T) { return false }) } + +// stmxcsr reads the MXCSR control and status register. +func stmxcsr(addr *uint32) + +func TestMXCSR(t *testing.T) { + applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { + var si arch.SignalInfo + switchOpts := ring0.SwitchOpts{ + Registers: regs, + FloatingPointState: &dummyFPState, + PageTables: pt, + FullRestore: true, + } + + const mxcsrControllMask = uint32(0x1f80) + mxcsrBefore := uint32(0) + mxcsrAfter := uint32(0) + stmxcsr(&mxcsrBefore) + if mxcsrBefore == 0 { + // goruntime sets mxcsr to 0x1f80 and it never changes + // the control configuration. + panic("mxcsr is zero") + } + switchOpts.FloatingPointState.SetMXCSR(0) + if _, err := c.SwitchToUser( + switchOpts, &si); err == platform.ErrContextInterrupt { + return true // Retry. + } else if err != nil { + t.Errorf("application syscall failed: %v", err) + } + stmxcsr(&mxcsrAfter) + if mxcsrAfter&mxcsrControllMask != mxcsrBefore&mxcsrControllMask { + t.Errorf("mxcsr = %x (expected %x)", mxcsrBefore, mxcsrAfter) + } + return false + }) +} diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/sentry/platform/kvm/kvm_amd64_test.s index d0f58cfaf..8e9079867 100644 --- a/pkg/tcpip/transport/tcp/cubic_state.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.s @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,18 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tcp +#include "textflag.h" -import ( - "time" -) - -// saveT is invoked by stateify. -func (c *cubicState) saveT() unixTime { - return unixTime{c.t.Unix(), c.t.UnixNano()} -} - -// loadT is invoked by stateify. -func (c *cubicState) loadT(unix unixTime) { - c.t = time.Unix(unix.second, unix.nano) -} +// stmxcsr reads the MXCSR control and status register. +TEXT ·stmxcsr(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), SI + STMXCSR (SI) + RET diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 5bce16dde..ceff09a60 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -22,6 +22,7 @@ import ( "time" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -29,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" ktime "gvisor.dev/gvisor/pkg/sentry/time" - "gvisor.dev/gvisor/pkg/usermem" ) var dummyFPState = fpu.NewState() @@ -142,8 +142,8 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func // done for regular user code, but is fine for test // purposes.) 
applyPhysicalRegions(func(pr physicalRegion) bool { - pt.Map(usermem.Addr(pr.virtual), pr.length, pagetables.MapOpts{ - AccessType: usermem.AnyAccess, + pt.Map(hostarch.Addr(pr.virtual), pr.length, pagetables.MapOpts{ + AccessType: hostarch.AnyAccess, User: true, }, pr.physical) return true // Keep iterating. @@ -351,7 +351,7 @@ func TestInvalidate(t *testing.T) { break // Done. } // Unmap the page containing data & invalidate. - pt.Unmap(usermem.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(usermem.PageSize-1)), usermem.PageSize) + pt.Unmap(hostarch.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(hostarch.PageSize-1)), hostarch.PageSize) for { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index cc0fb9892..99f036bba 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -21,13 +21,13 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) // machine contains state associated with the VM as a whole. @@ -235,9 +235,9 @@ func newMachine(vm int) (*machine, error) { applyPhysicalRegions(func(pr physicalRegion) bool { // Map everything in the lower half. m.kernel.PageTables.Map( - usermem.Addr(pr.virtual), + hostarch.Addr(pr.virtual), pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, + pagetables.MapOpts{AccessType: hostarch.AnyAccess}, pr.physical) return true // Keep iterating. @@ -444,7 +444,7 @@ func (m *machine) Get() *vCPU { } // The vCPU is not be able to transition to - // vCPUGuest|vCPUUser or to vCPUUser because that + // vCPUGuest|vCPUWaiter or to vCPUUser because that // transition requires holding the machine mutex, as we // do now. There is no path to register a waiter on // just the vCPUReady state. diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 31f56e3c2..d7abfefb4 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -24,13 +24,13 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" - "gvisor.dev/gvisor/pkg/usermem" ) // initArchState initializes architecture-specific state. @@ -41,7 +41,7 @@ func (m *machine) initArchState() error { unix.SYS_IOCTL, uintptr(m.fd), _KVM_SET_TSS_ADDR, - uintptr(reservedMemory-(3*usermem.PageSize))); errno != 0 { + uintptr(reservedMemory-(3*hostarch.PageSize))); errno != 0 { return errno } @@ -261,19 +261,19 @@ func (c *vCPU) setSystemTime() error { // nonCanonical generates a canonical address return. // //go:nosplit -func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { +func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { *info = arch.SignalInfo{ Signo: signal, Code: arch.SignalInfoKernel, } info.SetAddr(addr) // Include address. 
- return usermem.NoAccess, platform.ErrContextSignal + return hostarch.NoAccess, platform.ErrContextSignal } // fault generates an appropriate fault return. // //go:nosplit -func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { +func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { bluepill(c) // Probably no-op, but may not be. faultAddr := ring0.ReadCR2() code, user := c.ErrorCode() @@ -281,12 +281,12 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, e // The last fault serviced by this CPU was not a user // fault, so we can't reliably trust the faultAddr or // the code provided here. We need to re-execute. - return usermem.NoAccess, platform.ErrContextInterrupt + return hostarch.NoAccess, platform.ErrContextInterrupt } // Reset the pointed SignalInfo. *info = arch.SignalInfo{Signo: signal} info.SetAddr(uint64(faultAddr)) - accessType := usermem.AccessType{ + accessType := hostarch.AccessType{ Read: code&(1<<1) == 0, Write: code&(1<<1) != 0, Execute: code&(1<<4) != 0, @@ -315,14 +315,14 @@ func loadByte(ptr *byte) byte { //go:nosplit func prefaultFloatingPointState(data *fpu.State) { size := len(*data) - for i := 0; i < size; i += usermem.PageSize { + for i := 0; i < size; i += hostarch.PageSize { loadByte(&(*data)[i]) } loadByte(&(*data)[size-1]) } // SwitchToUser unpacks architectural-details. -func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) { +func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Rip) { return nonCanonical(regs.Rip, int32(unix.SIGSEGV), info) @@ -358,7 +358,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) switch vector { case ring0.Syscall, ring0.SyscallInt80: // Fast path: system call executed. - return usermem.NoAccess, nil + return hostarch.NoAccess, nil case ring0.PageFault: return c.fault(int32(unix.SIGSEGV), info) @@ -369,7 +369,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) Code: 1, // TRAP_BRKPT (breakpoint). } info.SetAddr(switchOpts.Registers.Rip) // Include address. - return usermem.AccessType{}, platform.ErrContextSignal + return hostarch.AccessType{}, platform.ErrContextSignal case ring0.GeneralProtectionFault, ring0.SegmentNotPresent, @@ -385,9 +385,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) // When CPUID faulting is enabled, we will generate a #GP(0) when // userspace executes a CPUID instruction. This is handled above, // because we need to be able to map and read user memory. - return usermem.AccessType{}, platform.ErrContextSignalCPUID + return hostarch.AccessType{}, platform.ErrContextSignalCPUID } - return usermem.AccessType{}, platform.ErrContextSignal + return hostarch.AccessType{}, platform.ErrContextSignal case ring0.InvalidOpcode: *info = arch.SignalInfo{ @@ -395,7 +395,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) Code: 1, // ILL_ILLOPC (illegal opcode). } info.SetAddr(switchOpts.Registers.Rip) // Include address. 
- return usermem.AccessType{}, platform.ErrContextSignal + return hostarch.AccessType{}, platform.ErrContextSignal case ring0.DivideByZero: *info = arch.SignalInfo{ @@ -403,7 +403,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) Code: 1, // FPE_INTDIV (divide by zero). } info.SetAddr(switchOpts.Registers.Rip) // Include address. - return usermem.AccessType{}, platform.ErrContextSignal + return hostarch.AccessType{}, platform.ErrContextSignal case ring0.Overflow: *info = arch.SignalInfo{ @@ -411,7 +411,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) Code: 2, // FPE_INTOVF (integer overflow). } info.SetAddr(switchOpts.Registers.Rip) // Include address. - return usermem.AccessType{}, platform.ErrContextSignal + return hostarch.AccessType{}, platform.ErrContextSignal case ring0.X87FloatingPointException, ring0.SIMDFloatingPointException: @@ -420,17 +420,17 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) Code: 7, // FPE_FLTINV (invalid operation). } info.SetAddr(switchOpts.Registers.Rip) // Include address. - return usermem.AccessType{}, platform.ErrContextSignal + return hostarch.AccessType{}, platform.ErrContextSignal case ring0.Vector(bounce): // ring0.VirtualizationException - return usermem.NoAccess, platform.ErrContextInterrupt + return hostarch.NoAccess, platform.ErrContextInterrupt case ring0.AlignmentCheck: *info = arch.SignalInfo{ Signo: int32(unix.SIGBUS), Code: 2, // BUS_ADRERR (physical address does not exist). } - return usermem.NoAccess, platform.ErrContextSignal + return hostarch.NoAccess, platform.ErrContextSignal case ring0.NMI: // An NMI is generated only when a fault is not servicable by @@ -476,9 +476,9 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { panic("impossible translation") } pageTable.Map( - usermem.Addr(ring0.KernelStartAddress|r.virtual), + hostarch.Addr(ring0.KernelStartAddress|r.virtual), r.length, - pagetables.MapOpts{AccessType: usermem.Execute}, + pagetables.MapOpts{AccessType: hostarch.Execute}, physical) } }) @@ -489,9 +489,9 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { panic("impossible translation") } pageTable.Map( - usermem.Addr(ring0.KernelStartAddress|start), + hostarch.Addr(ring0.KernelStartAddress|start), regionLen, - pagetables.MapOpts{AccessType: usermem.ReadWrite}, + pagetables.MapOpts{AccessType: hostarch.ReadWrite}, physical) } } diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 2edc9d1b2..cd912f922 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -17,12 +17,12 @@ package kvm import ( + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/usermem" ) type vCPUArchState struct { @@ -47,15 +47,15 @@ const ( // Beyond a relatively small number, there are likely few perform // benefits, since the TLB has likely long since lost any translations // from more than a few PCIDs past. 
- poolPCIDs = 8 + poolPCIDs = 128 ) func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { applyPhysicalRegions(func(pr physicalRegion) bool { pageTable.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), + hostarch.Addr(ring0.KernelStartAddress|pr.virtual), pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess, Global: true}, + pagetables.MapOpts{AccessType: hostarch.AnyAccess, Global: true}, pr.physical) return true // Keep iterating. @@ -117,13 +117,13 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { // nonCanonical generates a canonical address return. // //go:nosplit -func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { +func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { *info = arch.SignalInfo{ Signo: signal, Code: arch.SignalInfoKernel, } info.SetAddr(addr) // Include address. - return usermem.NoAccess, platform.ErrContextSignal + return hostarch.NoAccess, platform.ErrContextSignal } // isInstructionAbort returns true if it is an instruction abort. @@ -148,7 +148,7 @@ func isWriteFault(code uint64) bool { // fault generates an appropriate fault return. // //go:nosplit -func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { +func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { bluepill(c) // Probably no-op, but may not be. faultAddr := c.GetFaultAddr() code, user := c.ErrorCode() @@ -157,7 +157,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, e // The last fault serviced by this CPU was not a user // fault, so we can't reliably trust the faultAddr or // the code provided here. We need to re-execute. - return usermem.NoAccess, platform.ErrContextInterrupt + return hostarch.NoAccess, platform.ErrContextInterrupt } // Reset the pointed SignalInfo. @@ -174,7 +174,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, e info.Code = 2 } - accessType := usermem.AccessType{ + accessType := hostarch.AccessType{ Read: !isWriteFault(uint64(code)), Write: isWriteFault(uint64(code)), Execute: isInstructionAbort(uint64(code)), diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index e7d5f3193..634e55ec0 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -23,12 +23,12 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/usermem" ) type kvmVcpuInit struct { @@ -209,7 +209,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { } // SwitchToUser unpacks architectural-details. -func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) { +func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { return nonCanonical(regs.Pc, int32(unix.SIGSEGV), info) @@ -246,13 +246,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) switch vector { case ring0.Syscall: // Fast path: system call executed. 
- return usermem.NoAccess, nil + return hostarch.NoAccess, nil case ring0.PageFault: return c.fault(int32(unix.SIGSEGV), info) case ring0.El0ErrNMI: return c.fault(int32(unix.SIGBUS), info) case ring0.Vector(bounce): // ring0.VirtualizationException. - return usermem.NoAccess, platform.ErrContextInterrupt + return hostarch.NoAccess, platform.ErrContextInterrupt case ring0.El0SyncUndef: return c.fault(int32(unix.SIGILL), info) case ring0.El0SyncDbg: @@ -261,16 +261,16 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) Code: 1, // TRAP_BRKPT (breakpoint). } info.SetAddr(switchOpts.Registers.Pc) // Include address. - return usermem.AccessType{}, platform.ErrContextSignal + return hostarch.AccessType{}, platform.ErrContextSignal case ring0.El0SyncSpPc: *info = arch.SignalInfo{ Signo: int32(unix.SIGBUS), Code: 2, // BUS_ADRERR (physical address does not exist). } - return usermem.NoAccess, platform.ErrContextSignal + return hostarch.NoAccess, platform.ErrContextSignal case ring0.El0SyncSys, ring0.El0SyncWfx: - return usermem.NoAccess, nil // skip for now. + return hostarch.NoAccess, nil // skip for now. default: panic(fmt.Sprintf("unexpected vector: 0x%x", vector)) } diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index 7376d8b8d..d812e6c26 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -19,9 +19,9 @@ import ( "sort" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/ring0" - "gvisor.dev/gvisor/pkg/usermem" ) type region struct { @@ -81,7 +81,7 @@ func fillAddressSpace() (excludedRegions []region) { // faultBlockSize, potentially causing up to faultBlockSize bytes in // internal fragmentation for each physical region. So we need to // account for this properly during allocation. - requiredAddr, ok := usermem.Addr(vSize - pSize + faultBlockSize).RoundUp() + requiredAddr, ok := hostarch.Addr(vSize - pSize + faultBlockSize).RoundUp() if !ok { panic(fmt.Sprintf( "overflow for vSize (%x) - pSize (%x) + faultBlockSize (%x)", @@ -99,7 +99,7 @@ func fillAddressSpace() (excludedRegions []region) { 0, 0) if errno != 0 { // Attempt half the size; overflow not possible. 
- currentAddr, _ := usermem.Addr(current >> 1).RoundUp() + currentAddr, _ := hostarch.Addr(current >> 1).RoundUp() current = uintptr(currentAddr) continue } @@ -134,8 +134,8 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica return } if virtual == 0 { - virtual += usermem.PageSize - length -= usermem.PageSize + virtual += hostarch.PageSize + length -= hostarch.PageSize } if end := virtual + length; end > ring0.MaximumUserAddress { length -= (end - ring0.MaximumUserAddress) diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go index 4dcdbf8a7..01d9eb39d 100644 --- a/pkg/sentry/platform/kvm/virtual_map.go +++ b/pkg/sentry/platform/kvm/virtual_map.go @@ -22,12 +22,12 @@ import ( "regexp" "strconv" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) type virtualRegion struct { region - accessType usermem.AccessType + accessType hostarch.AccessType shared bool offset uintptr filename string @@ -92,7 +92,7 @@ func applyVirtualRegions(fn func(vr virtualRegion)) error { virtual: uintptr(start), length: uintptr(end - start), }, - accessType: usermem.AccessType{ + accessType: hostarch.AccessType{ Read: read, Write: write, Execute: execute, diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go index 9b4545fdd..1f4a774f3 100644 --- a/pkg/sentry/platform/kvm/virtual_map_test.go +++ b/pkg/sentry/platform/kvm/virtual_map_test.go @@ -18,12 +18,12 @@ import ( "testing" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) type checker struct { ok bool - accessType usermem.AccessType + accessType hostarch.AccessType } func (c *checker) Containing(addr uintptr) func(virtualRegion) { @@ -46,7 +46,7 @@ func TestParseMaps(t *testing.T) { // MMap a new page. addr, _, errno := unix.RawSyscall6( - unix.SYS_MMAP, 0, usermem.PageSize, + unix.SYS_MMAP, 0, hostarch.PageSize, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_ANONYMOUS|unix.MAP_PRIVATE, 0, 0) if errno != 0 { @@ -55,19 +55,19 @@ func TestParseMaps(t *testing.T) { // Re-parse maps. if err := applyVirtualRegions(c.Containing(addr)); err != nil { - unix.RawSyscall(unix.SYS_MUNMAP, addr, usermem.PageSize, 0) + unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0) t.Fatalf("unexpected error: %v", err) } // Assert that it now does contain the region. if !c.ok { - unix.RawSyscall(unix.SYS_MUNMAP, addr, usermem.PageSize, 0) + unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0) t.Fatalf("updated map does not contain 0x%08x, expected true", addr) } // Map the region as PROT_NONE. newAddr, _, errno := unix.RawSyscall6( - unix.SYS_MMAP, addr, usermem.PageSize, + unix.SYS_MMAP, addr, hostarch.PageSize, unix.PROT_NONE, unix.MAP_ANONYMOUS|unix.MAP_FIXED|unix.MAP_PRIVATE, 0, 0) if errno != 0 { @@ -89,5 +89,5 @@ func TestParseMaps(t *testing.T) { } // Unmap the region. - unix.RawSyscall(unix.SYS_MUNMAP, addr, usermem.PageSize, 0) + unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0) } diff --git a/pkg/sentry/platform/mmap_min_addr.go b/pkg/sentry/platform/mmap_min_addr.go index 091c2e365..7335bd802 100644 --- a/pkg/sentry/platform/mmap_min_addr.go +++ b/pkg/sentry/platform/mmap_min_addr.go @@ -20,7 +20,7 @@ import ( "strconv" "strings" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // systemMMapMinAddrSource is the source file. 
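The mmap_min_addr hunk that follows only changes the exposed type to hostarch.Addr; the value itself still comes from /proc/sys/vm/mmap_min_addr. A rough, self-contained sketch of reading and parsing that sysctl into a hostarch.Addr (readMMapMinAddr is a hypothetical helper; the sentry's own loader and caching differ):

    package main

    import (
        "fmt"
        "os"
        "strconv"
        "strings"

        "gvisor.dev/gvisor/pkg/hostarch"
    )

    // readMMapMinAddr parses /proc/sys/vm/mmap_min_addr into a
    // hostarch.Addr. Illustrative only.
    func readMMapMinAddr() (hostarch.Addr, error) {
        b, err := os.ReadFile("/proc/sys/vm/mmap_min_addr")
        if err != nil {
            return 0, err
        }
        v, err := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64)
        if err != nil {
            return 0, err
        }
        return hostarch.Addr(v), nil
    }

    func main() {
        if addr, err := readMMapMinAddr(); err == nil {
            fmt.Printf("minimum mappable address: %#x\n", addr)
        }
    }
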
@@ -30,8 +30,8 @@ const systemMMapMinAddrSource = "/proc/sys/vm/mmap_min_addr" var systemMMapMinAddr uint64 // SystemMMapMinAddr returns the minimum system address. -func SystemMMapMinAddr() usermem.Addr { - return usermem.Addr(systemMMapMinAddr) +func SystemMMapMinAddr() hostarch.Addr { + return hostarch.Addr(systemMMapMinAddr) } // MMapMinAddr is a size zero struct that implements MinUserAddress based on @@ -41,7 +41,7 @@ type MMapMinAddr struct { } // MinUserAddress implements platform.MinUserAddresss. -func (*MMapMinAddr) MinUserAddress() usermem.Addr { +func (*MMapMinAddr) MinUserAddress() hostarch.Addr { return SystemMMapMinAddr() } diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index dcfe839a7..ef7814a6f 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/hostmm" @@ -62,16 +63,16 @@ type Platform interface { // for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates // that the cost of AddressSpace.MapFile is effectively independent of the // number of pages mapped. If MapUnit is non-zero, it must be a power-of-2 - // multiple of usermem.PageSize. + // multiple of hostarch.PageSize. MapUnit() uint64 // MinUserAddress returns the minimum mappable address on this // platform. - MinUserAddress() usermem.Addr + MinUserAddress() hostarch.Addr // MaxUserAddress returns the maximum mappable address on this // platform. - MaxUserAddress() usermem.Addr + MaxUserAddress() hostarch.Addr // NewAddressSpace returns a new memory context for this platform. // @@ -172,7 +173,7 @@ type MemoryManager interface { //usermem.IO provides access to the contents of a virtual memory space. usermem.IO // MMap establishes a memory mapping. - MMap(ctx context.Context, opts memmap.MMapOpts) (usermem.Addr, error) + MMap(ctx context.Context, opts memmap.MMapOpts) (hostarch.Addr, error) // AddressSpace returns the AddressSpace bound to mm. AddressSpace() AddressSpace } @@ -195,7 +196,7 @@ type Context interface { // // - ErrContextSignal: The Context was interrupted by a signal. The // returned *arch.SignalInfo contains information about the signal. If - // arch.SignalInfo.Signo == SIGSEGV, the returned usermem.AccessType + // arch.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType // contains the access type of the triggering fault. The caller owns // the returned SignalInfo. // @@ -206,7 +207,7 @@ type Context interface { // concurrent call to Switch(). // // - ErrContextCPUPreempted: See the definition of that error for details. - Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) + Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) // PullFullState() pulls a full state of the application thread. // @@ -302,14 +303,14 @@ type AddressSpace interface { // * at.Any() == true. // * At least one reference must be held on all pages in fr, and must // continue to be held as long as pages are mapped. - MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error + MapFile(addr hostarch.Addr, f memmap.File, fr memmap.FileRange, at hostarch.AccessType, precommit bool) error // Unmap unmaps the given range. 
// // Preconditions: // * addr is page-aligned. // * length > 0. - Unmap(addr usermem.Addr, length uint64) + Unmap(addr hostarch.Addr, length uint64) // Release releases this address space. After releasing, a new AddressSpace // must be acquired via platform.NewAddressSpace(). @@ -337,67 +338,67 @@ type AddressSpaceIO interface { // CopyOut copies len(src) bytes from src to the memory mapped at addr. It // returns the number of bytes copied. If the number of bytes copied is < // len(src), it returns a non-nil error explaining why. - CopyOut(addr usermem.Addr, src []byte) (int, error) + CopyOut(addr hostarch.Addr, src []byte) (int, error) // CopyIn copies len(dst) bytes from the memory mapped at addr to dst. // It returns the number of bytes copied. If the number of bytes copied is // < len(dst), it returns a non-nil error explaining why. - CopyIn(addr usermem.Addr, dst []byte) (int, error) + CopyIn(addr hostarch.Addr, dst []byte) (int, error) // ZeroOut sets toZero bytes to 0, starting at addr. It returns the number // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a // non-nil error explaining why. - ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error) + ZeroOut(addr hostarch.Addr, toZero uintptr) (uintptr, error) // SwapUint32 atomically sets the uint32 value at addr to new and returns // the previous value. // // Preconditions: addr must be aligned to a 4-byte boundary. - SwapUint32(addr usermem.Addr, new uint32) (uint32, error) + SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) // CompareAndSwapUint32 atomically compares the uint32 value at addr to // old; if they are equal, the value in memory is replaced by new. In // either case, the previous value stored in memory is returned. // // Preconditions: addr must be aligned to a 4-byte boundary. - CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) + CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) // LoadUint32 atomically loads the uint32 value at addr and returns it. // // Preconditions: addr must be aligned to a 4-byte boundary. - LoadUint32(addr usermem.Addr) (uint32, error) + LoadUint32(addr hostarch.Addr) (uint32, error) } // NoAddressSpaceIO implements AddressSpaceIO methods by panicking. type NoAddressSpaceIO struct{} // CopyOut implements AddressSpaceIO.CopyOut. -func (NoAddressSpaceIO) CopyOut(addr usermem.Addr, src []byte) (int, error) { +func (NoAddressSpaceIO) CopyOut(addr hostarch.Addr, src []byte) (int, error) { panic("This platform does not support AddressSpaceIO") } // CopyIn implements AddressSpaceIO.CopyIn. -func (NoAddressSpaceIO) CopyIn(addr usermem.Addr, dst []byte) (int, error) { +func (NoAddressSpaceIO) CopyIn(addr hostarch.Addr, dst []byte) (int, error) { panic("This platform does not support AddressSpaceIO") } // ZeroOut implements AddressSpaceIO.ZeroOut. -func (NoAddressSpaceIO) ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error) { +func (NoAddressSpaceIO) ZeroOut(addr hostarch.Addr, toZero uintptr) (uintptr, error) { panic("This platform does not support AddressSpaceIO") } // SwapUint32 implements AddressSpaceIO.SwapUint32. -func (NoAddressSpaceIO) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) { +func (NoAddressSpaceIO) SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) { panic("This platform does not support AddressSpaceIO") } // CompareAndSwapUint32 implements AddressSpaceIO.CompareAndSwapUint32. 
-func (NoAddressSpaceIO) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) { +func (NoAddressSpaceIO) CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) { panic("This platform does not support AddressSpaceIO") } // LoadUint32 implements AddressSpaceIO.LoadUint32. -func (NoAddressSpaceIO) LoadUint32(addr usermem.Addr) (uint32, error) { +func (NoAddressSpaceIO) LoadUint32(addr hostarch.Addr) (uint32, error) { panic("This platform does not support AddressSpaceIO") } @@ -406,7 +407,7 @@ func (NoAddressSpaceIO) LoadUint32(addr usermem.Addr) (uint32, error) { // permissions. type SegmentationFault struct { // Addr is the address at which the fault occurred. - Addr usermem.Addr + Addr hostarch.Addr } // Error implements error.Error. diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 47efde6a2..d101f2f53 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -25,6 +25,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/procid", "//pkg/safecopy", @@ -35,7 +36,6 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sync", - "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 571bfcc2e..828458ce2 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -49,11 +49,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" pkgcontext "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) var ( @@ -88,28 +88,28 @@ type context struct { // lastFaultAddr is the last faulting address; this is only meaningful if // lastFaultSP is non-nil. - lastFaultAddr usermem.Addr + lastFaultAddr hostarch.Addr // lastFaultIP is the address of the last faulting instruction; // this is also only meaningful if lastFaultSP is non-nil. - lastFaultIP usermem.Addr + lastFaultIP hostarch.Addr } // Switch runs the provided context in the given address space. -func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) { +func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) { as := mm.AddressSpace() s := as.(*subprocess) isSyscall := s.switchToApp(c, ac) var ( faultSP *subprocess - faultAddr usermem.Addr - faultIP usermem.Addr + faultAddr hostarch.Addr + faultIP hostarch.Addr ) if !isSyscall && linux.Signal(c.signalInfo.Signo) == linux.SIGSEGV { faultSP = s - faultAddr = usermem.Addr(c.signalInfo.Addr()) - faultIP = usermem.Addr(ac.IP()) + faultAddr = hostarch.Addr(c.signalInfo.Addr()) + faultIP = hostarch.Addr(ac.IP()) } // Update the context to reflect the outcome of this context switch. @@ -140,14 +140,14 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a } if isSyscall { - return nil, usermem.NoAccess, nil + return nil, hostarch.NoAccess, nil } si := c.signalInfo if faultSP == nil { // Non-fault signal. - return &si, usermem.NoAccess, platform.ErrContextSignal + return &si, hostarch.NoAccess, platform.ErrContextSignal } // Got a page fault. 
Ideally, we'd get real fault type here, but ptrace @@ -157,7 +157,7 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a // pointer. // // It was a write fault if the fault is immediately repeated. - at := usermem.Read + at := hostarch.Read if faultAddr == faultIP { at.Execute = true } @@ -235,8 +235,8 @@ func (*PTrace) MapUnit() uint64 { // MaxUserAddress returns the first address that may not be used by user // applications. -func (*PTrace) MaxUserAddress() usermem.Addr { - return usermem.Addr(stubStart) +func (*PTrace) MaxUserAddress() hostarch.Addr { + return hostarch.Addr(stubStart) } // NewAddressSpace returns a new subprocess. diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 01e73b019..facb96011 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -19,9 +19,9 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" - "gvisor.dev/gvisor/pkg/usermem" ) // getRegs gets the general purpose register set. @@ -122,7 +122,7 @@ func (t *thread) getSignalInfo(si *arch.SignalInfo) error { // // Precondition: the OS thread must be locked and own t. func (t *thread) clone() (*thread, error) { - r, ok := usermem.Addr(stackPointer(&t.initRegs)).RoundUp() + r, ok := hostarch.Addr(stackPointer(&t.initRegs)).RoundUp() if !ok { return nil, unix.EINVAL } diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s index 16f9c523e..d5c3f901f 100644 --- a/pkg/sentry/platform/ptrace/stub_amd64.s +++ b/pkg/sentry/platform/ptrace/stub_amd64.s @@ -109,6 +109,12 @@ parent_dead: SYSCALL HLT +// func addrOfStub() uintptr +TEXT ·addrOfStub(SB), $0-8 + MOVQ $·stub(SB), AX + MOVQ AX, ret+0(FP) + RET + // stubCall calls the stub function at the given address with the given PPID. // // This is a distinct function because stub, above, may be mapped at any diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s index 6162df02a..4664cd4ad 100644 --- a/pkg/sentry/platform/ptrace/stub_arm64.s +++ b/pkg/sentry/platform/ptrace/stub_arm64.s @@ -102,6 +102,12 @@ parent_dead: SVC HLT +// func addrOfStub() uintptr +TEXT ·addrOfStub(SB), $0-8 + MOVD $·stub(SB), R0 + MOVD R0, ret+0(FP) + RET + // stubCall calls the stub function at the given address with the given PPID. // // This is a distinct function because stub, above, may be mapped at any diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go index 780227248..1fbdea898 100644 --- a/pkg/sentry/platform/ptrace/stub_unsafe.go +++ b/pkg/sentry/platform/ptrace/stub_unsafe.go @@ -19,13 +19,20 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safecopy" - "gvisor.dev/gvisor/pkg/usermem" ) // stub is defined in arch-specific assembly. func stub() +// addrOfStub returns the start address of stub. +// +// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal +// wrapper function rather than the function itself. We must reference from +// assembly to get the ABI0 (i.e., primary) address. +func addrOfStub() uintptr + // stubCall calls the stub at the given address with the given pid. func stubCall(addr, pid uintptr) @@ -41,12 +48,12 @@ func unsafeSlice(addr uintptr, length int) (slice []byte) { // stubInit initializes the stub. 
func stubInit() { // Grab the existing stub. - stubBegin := reflect.ValueOf(stub).Pointer() + stubBegin := addrOfStub() stubLen := int(safecopy.FindEndAddress(stubBegin) - stubBegin) stubSlice := unsafeSlice(stubBegin, stubLen) mapLen := uintptr(stubLen) - if offset := mapLen % usermem.PageSize; offset != 0 { - mapLen += usermem.PageSize - offset + if offset := mapLen % hostarch.PageSize; offset != 0 { + mapLen += hostarch.PageSize - offset } for stubStart > 0 { @@ -70,7 +77,7 @@ func stubInit() { } // Attempt to begin at a lower address. - stubStart -= uintptr(usermem.PageSize) + stubStart -= uintptr(hostarch.PageSize) continue } diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index acccbfe2e..9c73a725a 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -20,13 +20,13 @@ import ( "runtime" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" ) // Linux kernel errnos which "should never be seen by user programs", but will @@ -69,7 +69,7 @@ type thread struct { // threadPool is a collection of threads. type threadPool struct { // mu protects below. - mu sync.Mutex + mu sync.RWMutex // threads is the collection of threads. // @@ -85,30 +85,42 @@ type threadPool struct { // // Precondition: the runtime OS thread must be locked. func (tp *threadPool) lookupOrCreate(currentTID int32, newThread func() *thread) *thread { - tp.mu.Lock() + // The overwhelming common case is that the thread is already created. + // Optimistically attempt the lookup by only locking for reading. + tp.mu.RLock() t, ok := tp.threads[currentTID] - if !ok { - // Before creating a new thread, see if we can find a thread - // whose system tid has disappeared. - // - // TODO(b/77216482): Other parts of this package depend on - // threads never exiting. - for origTID, t := range tp.threads { - // Signal zero is an easy existence check. - if err := unix.Tgkill(unix.Getpid(), int(origTID), 0); err != nil { - // This thread has been abandoned; reuse it. - delete(tp.threads, origTID) - tp.threads[currentTID] = t - tp.mu.Unlock() - return t - } - } + tp.mu.RUnlock() + if ok { + return t + } - // Create a new thread. - t = newThread() - tp.threads[currentTID] = t + tp.mu.Lock() + defer tp.mu.Unlock() + + // Another goroutine might have created the thread for currentTID in between + // mu.RUnlock() and mu.Lock(). + if t, ok = tp.threads[currentTID]; ok { + return t + } + + // Before creating a new thread, see if we can find a thread + // whose system tid has disappeared. + // + // TODO(b/77216482): Other parts of this package depend on + // threads never exiting. + for origTID, t := range tp.threads { + // Signal zero is an easy existence check. + if err := unix.Tgkill(unix.Getpid(), int(origTID), 0); err != nil { + // This thread has been abandoned; reuse it. + delete(tp.threads, origTID) + tp.threads[currentTID] = t + return t + } } - tp.mu.Unlock() + + // Create a new thread. 
+ t = newThread() + tp.threads[currentTID] = t return t } @@ -228,7 +240,7 @@ func newSubprocess(create func() (*thread, error)) (*subprocess, error) { func (s *subprocess) unmap() { s.Unmap(0, uint64(stubStart)) if maximumUserAddress != stubEnd { - s.Unmap(usermem.Addr(stubEnd), uint64(maximumUserAddress-stubEnd)) + s.Unmap(hostarch.Addr(stubEnd), uint64(maximumUserAddress-stubEnd)) } } @@ -615,7 +627,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp } // MapFile implements platform.AddressSpace.MapFile. -func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { +func (s *subprocess) MapFile(addr hostarch.Addr, f memmap.File, fr memmap.FileRange, at hostarch.AccessType, precommit bool) error { var flags int if precommit { flags |= unix.MAP_POPULATE @@ -632,7 +644,7 @@ func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRan } // Unmap implements platform.AddressSpace.Unmap. -func (s *subprocess) Unmap(addr usermem.Addr, length uint64) { +func (s *subprocess) Unmap(addr hostarch.Addr, length uint64) { ar, ok := addr.ToRange(length) if !ok { panic(fmt.Sprintf("addr %#x + length %#x overflows", addr, length)) diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 0ce42b6cc..080859125 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/hostarch", "//pkg/marshal", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index ebcc891b3..0e0e82365 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/hostarch", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -23,7 +24,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserror", - "//pkg/usermem", ], ) @@ -35,8 +35,8 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/hostarch", "//pkg/sentry/socket", - "//pkg/usermem", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 65b556489..45a05cd63 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -20,13 +20,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) const maxInt = int(^uint(0) >> 1) @@ -181,12 +181,12 @@ func (c *scmCredentials) Equals(oc transport.CredentialsControlMessage) bool { } func putUint64(buf []byte, n uint64) []byte { - usermem.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n) + hostarch.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n) return buf[:len(buf)+8] } func putUint32(buf []byte, n uint32) []byte { - usermem.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n) + hostarch.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n) return buf[:len(buf)+4] } @@ -242,7 +242,7 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf hdrBuf := buf - buf = 
binary.Marshal(buf, usermem.ByteOrder, data) + buf = binary.Marshal(buf, hostarch.ByteOrder, data) // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { @@ -475,7 +475,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h) + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, syserror.EINVAL @@ -499,7 +499,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) } i += binary.AlignUp(length, width) @@ -510,7 +510,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err @@ -523,7 +523,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } var ts linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &ts) + binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &ts) cmsgs.IP.Timestamp = ts.ToNsecCapped() cmsgs.IP.HasTimestamp = true i += binary.AlignUp(length, width) @@ -539,7 +539,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTOS = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &cmsgs.IP.TOS) i += binary.AlignUp(length, width) case linux.IP_PKTINFO: @@ -549,7 +549,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) cmsgs.IP.PacketInfo = packetInfo i += binary.AlignUp(length, width) @@ -559,7 +559,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) @@ -583,7 +583,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTClass = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &cmsgs.IP.TClass) i += binary.AlignUp(length, width) case 
linux.IPV6_RECVORIGDSTADDR: @@ -591,7 +591,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go index d40a4cc85..7e28a0cef 100644 --- a/pkg/sentry/socket/control/control_test.go +++ b/pkg/sentry/socket/control/control_test.go @@ -22,8 +22,8 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/usermem" ) func TestParse(t *testing.T) { @@ -35,12 +35,12 @@ func TestParse(t *testing.T) { Type: linux.SO_TIMESTAMP, } buf := make([]byte, 0, length) - buf = binary.Marshal(buf, usermem.ByteOrder, &hdr) + buf = binary.Marshal(buf, hostarch.ByteOrder, &hdr) ts := linux.Timeval{ Sec: 2401, Usec: 343, } - buf = binary.Marshal(buf, usermem.ByteOrder, &ts) + buf = binary.Marshal(buf, hostarch.ByteOrder, &ts) cmsg, err := Parse(nil, nil, buf, 8 /* width */) if err != nil { diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a8e6f172b..a5c2155a2 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/binary", "//pkg/context", "//pkg/fdnotifier", + "//pkg/hostarch", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 2d9dbbdba..a784e23b5 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -321,7 +322,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. 
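Most of the changes in this file simply swap usermem.ByteOrder for hostarch.ByteOrder; both name the native byte order of the host the sentry runs on. A standalone sketch of detecting that order at runtime (illustrative only, not the actual hostarch implementation):

package main

import (
	"encoding/binary"
	"fmt"
	"unsafe"
)

// hostByteOrder determines the native byte order by inspecting how a known
// 16-bit value is laid out in memory.
func hostByteOrder() binary.ByteOrder {
	var x uint16 = 0x0102
	if *(*byte)(unsafe.Pointer(&x)) == 0x02 {
		return binary.LittleEndian
	}
	return binary.BigEndian
}

func main() {
	fmt.Println(hostByteOrder()) // LittleEndian on x86-64 and arm64 Linux hosts
}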
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -527,24 +528,24 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp) + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], hostarch.ByteOrder, &controlMessages.IP.Timestamp) } case linux.SOL_IP: switch unixCmsg.Header.Type { case linux.IP_TOS: controlMessages.IP.HasTOS = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &controlMessages.IP.TOS) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -557,11 +558,11 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &controlMessages.IP.TClass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -574,7 +575,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.TCP_INQ: controlMessages.IP.HasInq = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq) + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], hostarch.ByteOrder, &controlMessages.IP.Inq) } } } @@ -688,7 +689,7 @@ func (s *socketOpsCommon) State() uint32 { return 0 } - binary.Unmarshal(buf, usermem.ByteOrder, &info) + binary.Unmarshal(buf, hostarch.ByteOrder, &info) return uint32(info.State) } diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index 2890e640d..d3be2d825 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" 
"gvisor.dev/gvisor/pkg/sentry/socket" @@ -61,7 +62,7 @@ func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArgument return 0, translateIOSyscallError(errno) } var buf [4]byte - usermem.ByteOrder.PutUint32(buf[:], uint32(val)) + hostarch.ByteOrder.PutUint32(buf[:], uint32(val)) _, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{ AddressSpaceActive: true, }) diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 5bcf92e14..26e8ae17a 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -22,11 +22,13 @@ import ( "reflect" "strconv" "strings" + "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" @@ -146,7 +148,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } var ifinfo unix.IfInfomsg - binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], usermem.ByteOrder, &ifinfo) + binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], hostarch.ByteOrder, &ifinfo) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -177,7 +179,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } var ifaddr unix.IfAddrmsg - binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], usermem.ByteOrder, &ifaddr) + binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], hostarch.ByteOrder, &ifaddr) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, PrefixLen: ifaddr.Prefixlen, @@ -209,7 +211,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) } var ifRoute unix.RtMsg - binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], usermem.ByteOrder, &ifRoute) + binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], hostarch.ByteOrder, &ifRoute) inetRoute := inet.Route{ Family: ifRoute.Family, DstLen: ifRoute.Dst_len, @@ -243,7 +245,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) if len(attr.Value) != expected { return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected) } - binary.Unmarshal(attr.Value, usermem.ByteOrder, &inetRoute.OutputInterface) + binary.Unmarshal(attr.Value, hostarch.ByteOrder, &inetRoute.OutputInterface) } } diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 8aea0200f..4381dfa06 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -20,12 +20,12 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/hostarch", "//pkg/log", "//pkg/sentry/kernel", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/header", "//pkg/tcpip/stack", - "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index e339f9bea..4bd305a44 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -19,10 +19,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + 
"gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/usermem" ) // TODO(gvisor.dev/issue/170): The following per-matcher params should be @@ -89,7 +89,7 @@ func marshalEntryMatch(name string, data []byte) []byte { copy(matcher.Name[:], name) buf := make([]byte, 0, size) - buf = binary.Marshal(buf, usermem.ByteOrder, matcher) + buf = binary.Marshal(buf, hostarch.ByteOrder, matcher) return append(buf, make([]byte, size-len(buf))...) } diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 2f913787b..1fc4cb651 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -19,11 +19,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/usermem" ) // emptyIPv4Filter is for comparison with a rule's filters to determine whether @@ -142,7 +142,7 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, } var entry linux.IPTEntry buf := optVal[:linux.SizeOfIPTEntry] - binary.Unmarshal(buf, usermem.ByteOrder, &entry) + binary.Unmarshal(buf, hostarch.ByteOrder, &entry) initialOptValLen := len(optVal) optVal = optVal[linux.SizeOfIPTEntry:] diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 263d9d3b5..67a52b628 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -19,11 +19,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/usermem" ) // emptyIPv6Filter is for comparison with a rule's filters to determine whether @@ -145,7 +145,7 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, } var entry linux.IP6TEntry buf := optVal[:linux.SizeOfIP6TEntry] - binary.Unmarshal(buf, usermem.ByteOrder, &entry) + binary.Unmarshal(buf, hostarch.ByteOrder, &entry) initialOptValLen := len(optVal) optVal = optVal[linux.SizeOfIP6TEntry:] diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 7ae18b2a3..c6fa3fd16 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -23,12 +23,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/usermem" ) // enableLogging controls whether to log the (de)serialization of netfilter @@ -83,7 +83,7 @@ func DefaultLinuxTables() *stack.IPTables { } // GetInfo returns information about iptables. -func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { +func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. 
var info linux.IPTGetinfo if _, err := info.CopyIn(t, outPtr); err != nil { @@ -106,7 +106,7 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) } // GetEntries4 returns netstack's iptables rules. -func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { +func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries if _, err := userEntries.CopyIn(t, outPtr); err != nil { @@ -130,7 +130,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen } // GetEntries6 returns netstack's ip6tables rules. -func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) { +func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) { // Read in the struct and table name. IPv4 and IPv6 utilize structs // with the same layout. var userEntries linux.IPTGetEntries @@ -179,7 +179,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] - binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace) + binary.Unmarshal(replaceBuf, hostarch.ByteOrder, &replace) // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table @@ -274,10 +274,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } // TODO(gvisor.dev/issue/170): Support other chains. - // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, - // make sure all other chains point to ACCEPT rules. + // Since we don't support FORWARD, yet, make sure all other chains point to + // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if hook := stack.Hook(hook); hook == stack.Forward { if ruleIdx == stack.HookUnset { continue } @@ -310,7 +310,7 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, } var match linux.XTEntryMatch buf := optVal[:linux.SizeOfXTEntryMatch] - binary.Unmarshal(buf, usermem.ByteOrder, &match) + binary.Unmarshal(buf, hostarch.ByteOrder, &match) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. @@ -381,7 +381,7 @@ func hookFromLinux(hook int) stack.Hook { // TargetRevision returns a linux.XTGetRevision for a given target. It sets // Revision to the highest supported value, unless the provided revision number // is larger. -func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) { +func TargetRevision(t *kernel.Task, revPtr hostarch.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) { // Read in the target name and version. 
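parseMatchers above walks optVal as a sequence of matches, each starting with an XTEntryMatch header whose size field covers the header plus its data. A simplified, standalone walk over size-prefixed records; the bare 2-byte size prefix here is illustrative and not the real XTEntryMatch layout:

package main

import (
	"encoding/binary"
	"fmt"
)

// walkRecords visits each record in buf. Every record starts with a 2-byte
// little-endian total size that includes the prefix itself.
func walkRecords(buf []byte, visit func(data []byte)) error {
	for len(buf) > 0 {
		if len(buf) < 2 {
			return fmt.Errorf("truncated record header")
		}
		size := int(binary.LittleEndian.Uint16(buf[:2]))
		if size < 2 || size > len(buf) {
			return fmt.Errorf("bad record size %d", size)
		}
		visit(buf[2:size])
		buf = buf[size:]
	}
	return nil
}

func main() {
	// Two records, "tcp" and "udp", each 5 bytes total.
	buf := []byte{5, 0, 't', 'c', 'p', 5, 0, 'u', 'd', 'p'}
	_ = walkRecords(buf, func(data []byte) { fmt.Println(string(data)) })
}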
var rev linux.XTGetRevision if _, err := rev.CopyIn(t, revPtr); err != nil { diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 5f80d82ea..b2cc6be20 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/usermem" ) const matcherNameOwner = "owner" @@ -60,7 +60,7 @@ func (ownerMarshaler) marshal(mr matcher) []byte { } buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) - return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo)) + return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, hostarch.ByteOrder, iptOwnerInfo)) } // unmarshal implements matchMaker.unmarshal. @@ -72,7 +72,7 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData) + binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], hostarch.ByteOrder, &matchData) nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index f2653d523..38b6491e2 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -19,11 +19,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/usermem" ) // ErrorTargetName is used to mark targets as error targets. Error targets @@ -35,6 +35,11 @@ const ErrorTargetName = "ERROR" // change the destination port and/or IP for packets. const RedirectTargetName = "REDIRECT" +// SNATTargetName is used to mark targets as SNAT targets. SNAT targets should +// be reached for only NAT table. These targets will change the source port +// and/or IP for packets. +const SNATTargetName = "SNAT" + func init() { // Standard targets include ACCEPT, DROP, RETURN, and JUMP. registerTargetMaker(&standardTargetMaker{ @@ -59,6 +64,13 @@ func init() { registerTargetMaker(&nfNATTargetMaker{ NetworkProtocol: header.IPv6ProtocolNumber, }) + + registerTargetMaker(&snatTargetMakerV4{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&snatTargetMakerV6{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) } // The stack package provides some basic, useful targets for us. 
The following @@ -131,6 +143,17 @@ func (rt *redirectTarget) id() targetID { } } +type snatTarget struct { + stack.SNATTarget +} + +func (st *snatTarget) id() targetID { + return targetID{ + name: SNATTargetName, + networkProtocol: st.NetworkProtocol, + } +} + type standardTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } @@ -167,7 +190,7 @@ func (*standardTargetMaker) marshal(target target) []byte { } ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, xt) + return binary.Marshal(ret, hostarch.ByteOrder, xt) } func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -177,7 +200,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( } var standardTarget linux.XTStandardTarget buf = buf[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) + binary.Unmarshal(buf, hostarch.ByteOrder, &standardTarget) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. @@ -223,7 +246,7 @@ func (*errorTargetMaker) marshal(target target) []byte { copy(xt.Target.Name[:], ErrorTargetName) ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, usermem.ByteOrder, xt) + return binary.Marshal(ret, hostarch.ByteOrder, xt) } func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -233,7 +256,7 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errTgt) + binary.Unmarshal(buf, hostarch.ByteOrder, &errTgt) // Error targets are used in 2 cases: // * An actual error case. These rules have an error named @@ -281,7 +304,7 @@ func (*redirectTargetMaker) marshal(target target) []byte { xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED xt.NfRange.RangeIPV4.MinPort = htons(rt.Port) xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort - return binary.Marshal(ret, usermem.ByteOrder, xt) + return binary.Marshal(ret, hostarch.ByteOrder, xt) } func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -297,7 +320,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &rt) + binary.Unmarshal(buf, hostarch.ByteOrder, &rt) // Copy linux.XTRedirectTarget to stack.RedirectTarget. 
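The target makers above serialize fixed-size structs such as XTStandardTarget with gVisor's pkg/binary, whose Marshal appends to a byte slice in the given order. A rough standard-library equivalent using encoding/binary and a bytes.Buffer; the entryTarget struct here is illustrative, not the real linux.XTEntryTarget layout:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// entryTarget is an illustrative fixed-size header: a total size followed by
// a fixed-width, NUL-padded name, loosely modeled on xt_entry_target.
type entryTarget struct {
	TargetSize uint16
	Name       [30]byte
}

func marshalTarget(name string, size uint16) ([]byte, error) {
	var t entryTarget
	t.TargetSize = size
	copy(t.Name[:], name)
	var buf bytes.Buffer
	// Assuming a little-endian host; these iptables structs are exchanged in
	// host byte order.
	if err := binary.Write(&buf, binary.LittleEndian, &t); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

func main() {
	b, err := marshalTarget("SNAT", 56)
	fmt.Println(len(b), err) // 32 <nil>
}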
target := redirectTarget{RedirectTarget: stack.RedirectTarget{ @@ -341,7 +364,7 @@ type nfNATTarget struct { Range linux.NFNATRange } -const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange +const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange type nfNATTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber @@ -358,7 +381,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte { rt := target.(*redirectTarget) nt := nfNATTarget{ Target: linux.XTEntryTarget{ - TargetSize: nfNATMarhsalledSize, + TargetSize: nfNATMarshalledSize, }, Range: linux.NFNATRange{ Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, @@ -371,12 +394,12 @@ func (*nfNATTargetMaker) marshal(target target) []byte { nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto - ret := make([]byte, 0, nfNATMarhsalledSize) - return binary.Marshal(ret, usermem.ByteOrder, nt) + ret := make([]byte, 0, nfNATMarshalledSize) + return binary.Marshal(ret, hostarch.ByteOrder, nt) } func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { - if size := nfNATMarhsalledSize; len(buf) < size { + if size := nfNATMarshalledSize; len(buf) < size { nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) return nil, syserr.ErrInvalidArgument } @@ -387,8 +410,8 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize] - binary.Unmarshal(buf, usermem.ByteOrder, &natRange) + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] + binary.Unmarshal(buf, hostarch.ByteOrder, &natRange) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -418,6 +441,161 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return &target, nil } +type snatTargetMakerV4 struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (st *snatTargetMakerV4) id() targetID { + return targetID{ + name: SNATTargetName, + networkProtocol: st.NetworkProtocol, + } +} + +func (*snatTargetMakerV4) marshal(target target) []byte { + st := target.(*snatTarget) + // This is a snat target named snat. + xt := linux.XTSNATTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTSNATTarget, + }, + } + copy(xt.Target.Name[:], SNATTargetName) + + xt.NfRange.RangeSize = 1 + xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED + xt.NfRange.RangeIPV4.MinPort = htons(st.Port) + xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort + copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr) + copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr) + ret := make([]byte, 0, linux.SizeOfXTSNATTarget) + return binary.Marshal(ret, hostarch.ByteOrder, xt) +} + +func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { + if len(buf) < linux.SizeOfXTSNATTarget { + nflog("snatTargetMakerV4: buf has insufficient size for snat target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("snatTargetMakerV4: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var st linux.XTSNATTarget + buf = buf[:linux.SizeOfXTSNATTarget] + binary.Unmarshal(buf, hostarch.ByteOrder, &st) + + // Copy linux.XTSNATTarget to stack.SNATTarget. 
+ target := snatTarget{SNATTarget: stack.SNATTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }} + + // RangeSize should be 1. + nfRange := st.NfRange + if nfRange.RangeSize != 1 { + nflog("snatTargetMakerV4: bad rangesize %d", nfRange.RangeSize) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/5772): If the rule doesn't specify the source port, + // choose one automatically. + if nfRange.RangeIPV4.MinPort == 0 { + nflog("snatTargetMakerV4: snat target needs to specify a non-zero port") + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. + if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument + } + if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP { + nflog("snatTargetMakerV4: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument + } + + target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.Port = ntohs(nfRange.RangeIPV4.MinPort) + + return &target, nil +} + +type snatTargetMakerV6 struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (st *snatTargetMakerV6) id() targetID { + return targetID{ + name: SNATTargetName, + networkProtocol: st.NetworkProtocol, + revision: 1, + } +} + +func (*snatTargetMakerV6) marshal(target target) []byte { + st := target.(*snatTarget) + nt := nfNATTarget{ + Target: linux.XTEntryTarget{ + TargetSize: nfNATMarshalledSize, + }, + Range: linux.NFNATRange{ + Flags: linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED, + }, + } + copy(nt.Target.Name[:], SNATTargetName) + copy(nt.Range.MinAddr[:], st.Addr) + copy(nt.Range.MaxAddr[:], st.Addr) + nt.Range.MinProto = htons(st.Port) + nt.Range.MaxProto = nt.Range.MinProto + + ret := make([]byte, 0, nfNATMarshalledSize) + return binary.Marshal(ret, hostarch.ByteOrder, nt) +} + +func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { + if size := nfNATMarshalledSize; len(buf) < size { + nflog("snatTargetMakerV6: buf has insufficient size (%d) for SNAT V6 target (%d)", len(buf), size) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("snatTargetMakerV6: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var natRange linux.NFNATRange + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] + binary.Unmarshal(buf, hostarch.ByteOrder, &natRange) + + // TODO(gvisor.dev/issue/5689): Support port or address ranges. + if natRange.MinAddr != natRange.MaxAddr { + nflog("snatTargetMakerV6: MinAddr and MaxAddr are different") + return nil, syserr.ErrInvalidArgument + } + if natRange.MinProto != natRange.MaxProto { + nflog("snatTargetMakerV6: MinProto and MaxProto are different") + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/5698): Support other NF_NAT_RANGE flags. 
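The SNAT unmarshal paths convert the port fields with ntohs because ports in these structs are stored in network order. A standalone sketch of the ntohs/htons helpers shown further below in this diff, with binary.LittleEndian standing in for hostarch.ByteOrder on a little-endian host:

package main

import (
	"encoding/binary"
	"fmt"
)

// ntohs converts a 16-bit value from network (big-endian) to host order.
func ntohs(port uint16) uint16 {
	var buf [2]byte
	binary.BigEndian.PutUint16(buf[:], port)
	return binary.LittleEndian.Uint16(buf[:])
}

// htons converts a 16-bit value from host to network (big-endian) order.
func htons(port uint16) uint16 {
	var buf [2]byte
	binary.LittleEndian.PutUint16(buf[:], port)
	return binary.BigEndian.Uint16(buf[:])
}

func main() {
	fmt.Printf("%#x %#x\n", htons(80), ntohs(htons(80))) // 0x5000 0x50
}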
+ if natRange.Flags != linux.NF_NAT_RANGE_MAP_IPS|linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("snatTargetMakerV6: invalid range flags %d", natRange.Flags) + return nil, syserr.ErrInvalidArgument + } + + target := snatTarget{ + SNATTarget: stack.SNATTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Addr: tcpip.Address(natRange.MinAddr[:]), + Port: ntohs(natRange.MinProto), + }, + } + + return &target, nil +} + // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { @@ -454,7 +632,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T } var target linux.XTEntryTarget buf := optVal[:linux.SizeOfXTEntryTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &target) + binary.Unmarshal(buf, hostarch.ByteOrder, &target) return unmarshalTarget(target, filter, optVal) } @@ -487,11 +665,11 @@ func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, func ntohs(port uint16) uint16 { buf := make([]byte, 2) binary.BigEndian.PutUint16(buf, port) - return usermem.ByteOrder.Uint16(buf) + return hostarch.ByteOrder.Uint16(buf) } func htons(port uint16) uint16 { buf := make([]byte, 2) - usermem.ByteOrder.PutUint16(buf, port) + hostarch.ByteOrder.PutUint16(buf, port) return binary.BigEndian.Uint16(buf) } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 678d6b578..69557f515 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -19,9 +19,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/usermem" ) const matcherNameTCP = "tcp" @@ -48,7 +48,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte { DestinationPortEnd: matcher.destinationPortEnd, } buf := make([]byte, 0, linux.SizeOfXTTCP) - return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp)) + return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, hostarch.ByteOrder, xttcp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +60,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. 
var matchData linux.XTTCP - binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData) + binary.Unmarshal(buf[:linux.SizeOfXTTCP], hostarch.ByteOrder, &matchData) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index f8568873f..6a60e6bd6 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -19,9 +19,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/usermem" ) const matcherNameUDP = "udp" @@ -48,7 +48,7 @@ func (udpMarshaler) marshal(mr matcher) []byte { DestinationPortEnd: matcher.destinationPortEnd, } buf := make([]byte, 0, linux.SizeOfXTUDP) - return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp)) + return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, hostarch.ByteOrder, xtudp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +60,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. var matchData linux.XTUDP - binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData) + binary.Unmarshal(buf[:linux.SizeOfXTUDP], hostarch.ByteOrder, &matchData) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 9313e1167..171b95c63 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/hostarch", "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/sentry/arch", diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index 0899c61d1..ab0e68af7 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -20,7 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // alignPad returns the length of padding required for alignment. @@ -42,7 +42,7 @@ type Message struct { func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ hdr: hdr, - buf: binary.Marshal(nil, usermem.ByteOrder, hdr), + buf: binary.Marshal(nil, hostarch.ByteOrder, hdr), } } @@ -58,7 +58,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { return } var hdr linux.NetlinkMessageHeader - binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) // Msg portion. totalMsgLen := int(hdr.Length) @@ -105,7 +105,7 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) { if !ok { return nil, false } - binary.Unmarshal(msgBytes, usermem.ByteOrder, msg) + binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg) numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) // Linux permits the last message not being aligned, just consume all of it. @@ -126,7 +126,7 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) { // calling Finalize. func (m *Message) Finalize() []byte { // Update length, which is the first 4 bytes of the header. 
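Message.Finalize writes the total useful length into the first 4 bytes of the header, as the comment above notes, and then pads the buffer to the netlink alignment boundary (NLMSG_ALIGNTO is 4). A simplified standalone version of those two steps, again assuming a little-endian host:

package main

import (
	"encoding/binary"
	"fmt"
)

const nlmsgAlignTo = 4

// finalize patches the 4-byte length prefix with the useful message length
// and pads the buffer to the next 4-byte boundary.
func finalize(msg []byte) []byte {
	binary.LittleEndian.PutUint32(msg[:4], uint32(len(msg)))
	if pad := len(msg) % nlmsgAlignTo; pad != 0 {
		msg = append(msg, make([]byte, nlmsgAlignTo-pad)...)
	}
	return msg
}

func main() {
	msg := make([]byte, 4)              // space for the length header
	msg = append(msg, []byte("abc")...) // 3 bytes of payload
	msg = finalize(msg)
	fmt.Println(binary.LittleEndian.Uint32(msg[:4]), len(msg)) // 7 8
}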
- usermem.ByteOrder.PutUint32(m.buf, uint32(len(m.buf))) + hostarch.ByteOrder.PutUint32(m.buf, uint32(len(m.buf))) // Align the message. Note that the message length in the header (set // above) is the useful length of the message, not the total aligned @@ -146,7 +146,7 @@ func (m *Message) putZeros(n int) { // Put serializes v into the message. func (m *Message) Put(v interface{}) { - m.buf = binary.Marshal(m.buf, usermem.ByteOrder, v) + m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v) } // PutAttr adds v to the message as a netlink attribute. @@ -251,7 +251,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest if !ok { return } - binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) if !ok { diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d5ffc75ce..30c297149 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -222,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - binary.Unmarshal(b[:linux.SockAddrNetlinkSize], usermem.ByteOrder, &sa) + binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument @@ -327,7 +328,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. 
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -388,7 +389,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] if len(opt) < sizeOfInt32 { return syserr.ErrInvalidArgument } - size := usermem.ByteOrder.Uint32(opt) + size := hostarch.ByteOrder.Uint32(opt) if size < minSendBufferSize { size = minSendBufferSize } else if size > maxSendBufferSize { @@ -411,7 +412,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] if len(opt) < sizeOfInt32 { return syserr.ErrInvalidArgument } - passcred := usermem.ByteOrder.Uint32(opt) + passcred := hostarch.ByteOrder.Uint32(opt) s.ep.SocketOptions().SetPassCred(passcred != 0) return nil diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 244d99436..0b39a5b67 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 64e70ab9d..312f5f85a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -37,6 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -241,6 +242,7 @@ var Metrics = tcpip.Stats{ FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."), Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."), ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), + FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."), }, UDP: tcpip.UDPStats{ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."), @@ -600,7 +602,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { return syserr.ErrInvalidArgument } - family := usermem.ByteOrder.Uint16(sockaddr) + family := hostarch.ByteOrder.Uint16(sockaddr) var addr tcpip.FullAddress // Bind for AF_PACKET requires only family, protocol and ifindex. @@ -611,7 +613,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a) + binary.Unmarshal(sockaddr[:sockAddrLinkSize], hostarch.ByteOrder, &a) if a.Protocol != uint16(s.protocol) { return syserr.ErrInvalidArgument @@ -757,7 +759,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. 
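Bind above reads the address family from the first two bytes of the raw sockaddr before deciding how to decode the rest. A minimal sketch of that first step, with a hypothetical sockaddrFamily helper and binary.LittleEndian standing in for hostarch.ByteOrder:

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

const afInet = 2 // linux.AF_INET

// sockaddrFamily returns the sa_family field of a raw struct sockaddr,
// which occupies the first two bytes in host byte order.
func sockaddrFamily(sockaddr []byte) (uint16, error) {
	if len(sockaddr) < 2 {
		return 0, errors.New("sockaddr too short")
	}
	return binary.LittleEndian.Uint16(sockaddr[:2]), nil
}

func main() {
	raw := []byte{afInet, 0, 0x1f, 0x90, 127, 0, 0, 1} // AF_INET, port 8080, 127.0.0.1
	fam, err := sockaddrFamily(raw)
	fmt.Println(fam, err) // 2 <nil>
}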
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -793,7 +795,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -884,10 +886,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } + size := ep.SocketOptions().GetReceiveBufferSize() if size > math.MaxInt32 { size = math.MaxInt32 @@ -1244,7 +1243,7 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if _, ok := ep.(tcpip.Endpoint); !ok { log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) return nil, syserr.ErrUnknownProtocolOption @@ -1392,7 +1391,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // getSockOptIP implements GetSockOpt when level is SOL_IP. 
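The SO_RCVBUF getsockopt path above now reads the buffer size from SocketOptions and clamps it so it fits the 32-bit value copied back to userspace. A minimal sketch of that clamp; the lower bound here is added for symmetry and is not in the original:

package main

import (
	"fmt"
	"math"
)

// clampToInt32 truncates an int64 option value into the int32 range.
func clampToInt32(v int64) int32 {
	if v > math.MaxInt32 {
		return math.MaxInt32
	}
	if v < math.MinInt32 {
		return math.MinInt32
	}
	return int32(v)
}

func main() {
	fmt.Println(clampToInt32(1 << 40)) // 2147483647
}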
-func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { +func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { if _, ok := ep.(tcpip.Endpoint); !ok { log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) return nil, syserr.ErrUnknownProtocolOption @@ -1602,7 +1601,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa } s.readMu.Lock() defer s.readMu.Unlock() - s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0 + s.sockOptTimestamp = hostarch.ByteOrder.Uint32(optVal) != 0 return nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { @@ -1611,7 +1610,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa } s.readMu.Lock() defer s.readMu.Unlock() - s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0 + s.sockOptInq = hostarch.ByteOrder.Uint32(optVal) != 0 return nil } @@ -1659,8 +1658,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) - ep.SocketOptions().SetSendBufferSize(int64(v), true) + v := hostarch.ByteOrder.Uint32(optVal) + ep.SocketOptions().SetSendBufferSize(int64(v), true /* notify */) return nil case linux.SO_RCVBUF: @@ -1668,15 +1667,16 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) + v := hostarch.ByteOrder.Uint32(optVal) + ep.SocketOptions().SetReceiveBufferSize(int64(v), true /* notify */) + return nil case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetReuseAddress(v != 0) return nil @@ -1685,7 +1685,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetReusePort(v != 0) return nil @@ -1714,7 +1714,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetBroadcast(v != 0) return nil @@ -1723,7 +1723,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetPassCred(v != 0) return nil @@ -1732,7 +1732,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetKeepAlive(v != 0) return nil @@ -1742,7 +1742,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v) + binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return 
syserr.ErrDomain } @@ -1755,7 +1755,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v) + binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1767,7 +1767,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) if v == 0 { socket.SetSockOptEmitUnimplementedEvent(t, name) @@ -1781,7 +1781,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetNoChecksum(v != 0) return nil @@ -1791,7 +1791,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - binary.Unmarshal(optVal[:linux.SizeOfLinger], usermem.ByteOrder, &v) + binary.Unmarshal(optVal[:linux.SizeOfLinger], hostarch.ByteOrder, &v) ep.SocketOptions().SetLinger(tcpip.LingerOption{ Enabled: v.OnOff != 0, @@ -1824,7 +1824,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetDelayOption(v == 0) return nil @@ -1833,7 +1833,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetCorkOption(v != 0) return nil @@ -1842,7 +1842,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetQuickAck(v != 0) return nil @@ -1851,7 +1851,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v))) case linux.TCP_KEEPIDLE: @@ -1859,7 +1859,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) if v < 1 || v > linux.MAX_TCP_KEEPIDLE { return syserr.ErrInvalidArgument } @@ -1871,7 +1871,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) if v < 1 || v > linux.MAX_TCP_KEEPINTVL { return syserr.ErrInvalidArgument } @@ -1883,7 +1883,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) if v < 1 || v > linux.MAX_TCP_KEEPCNT { return syserr.ErrInvalidArgument } @@ -1894,7 +1894,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := int32(usermem.ByteOrder.Uint32(optVal)) + v := int32(hostarch.ByteOrder.Uint32(optVal)) if v < 0 { return 
syserr.ErrInvalidArgument } @@ -1913,7 +1913,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i return syserr.ErrInvalidArgument } - v := int32(usermem.ByteOrder.Uint32(optVal)) + v := int32(hostarch.ByteOrder.Uint32(optVal)) opt := tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)) return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) @@ -1921,7 +1921,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } - v := int32(usermem.ByteOrder.Uint32(optVal)) + v := int32(hostarch.ByteOrder.Uint32(optVal)) if v < 0 { v = 0 } @@ -1932,7 +1932,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v))) @@ -1940,7 +1940,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v))) @@ -1978,7 +1978,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return syserr.ErrInvalidEndpointState } - v := usermem.ByteOrder.Uint32(optVal) + v := hostarch.ByteOrder.Uint32(optVal) ep.SocketOptions().SetV6Only(v != 0) return nil @@ -2024,7 +2024,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } - v := int32(usermem.ByteOrder.Uint32(optVal)) + v := int32(hostarch.ByteOrder.Uint32(optVal)) ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) return nil @@ -2033,7 +2033,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } - v := int32(usermem.ByteOrder.Uint32(optVal)) + v := int32(hostarch.ByteOrder.Uint32(optVal)) if v < -1 || v > 255 { return syserr.ErrInvalidArgument } @@ -2117,12 +2117,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], usermem.ByteOrder, &req) + binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], hostarch.ByteOrder, &req) return req, nil } var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestSize], usermem.ByteOrder, &req.InetMulticastRequest) + binary.Unmarshal(optVal[:inetMulticastRequestSize], hostarch.ByteOrder, &req.InetMulticastRequest) return req, nil } @@ -2132,7 +2132,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - binary.Unmarshal(optVal[:inet6MulticastRequestSize], usermem.ByteOrder, &req) + binary.Unmarshal(optVal[:inet6MulticastRequestSize], hostarch.ByteOrder, &req) return req, nil } @@ -2145,7 +2145,7 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { } if len(buf) >= sizeOfInt32 { - return int32(usermem.ByteOrder.Uint32(buf)), nil + return int32(hostarch.ByteOrder.Uint32(buf)), nil } return int32(buf[0]), nil @@ -3007,7 +3007,7 @@ func interfaceIoctl(ctx context.Context, io 
usermem.IO, arg int, ifr *linux.IFRe if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. - index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) + index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4])) if iface, ok := stack.Interfaces()[index]; ok { ifr.SetName(iface.Name) return nil @@ -3029,7 +3029,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe switch arg { case linux.SIOCGIFINDEX: // Copy out the index to the data. - usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) + hostarch.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) case linux.SIOCGIFHWADDR: // Copy the hardware address out. @@ -3042,7 +3042,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // sockaddr. sa_family contains the ARPHRD_* device type, // sa_data the L2 hardware address starting from byte 0. Setting // the hardware address is a privileged operation. - usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) + hostarch.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. @@ -3055,7 +3055,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // Drop the flags that don't fit in the size that we need to return. This // matches Linux behavior. - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) + hostarch.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) case linux.SIOCGIFADDR: // Copy the IPv4 address out. @@ -3071,11 +3071,11 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. - usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) + hostarch.ByteOrder.PutUint32(ifr.Data[:4], 0) case linux.SIOCGIFMTU: // Gets the MTU of the device. - usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) + hostarch.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. @@ -3101,8 +3101,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe continue } // Populate ifr.ifr_netmask (type sockaddr). - usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET)) - usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0) + hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET)) + hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0) var mask uint32 = 0xffffffff << (32 - addr.PrefixLen) // Netmask is expected to be returned as a big endian // value. @@ -3157,14 +3157,14 @@ func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux. // Populate ifr.ifr_addr. ifr := linux.IFReq{} ifr.SetName(iface.Name) - usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family)) - usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0) + hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family)) + hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0) copy(ifr.Data[4:8], ifaceAddr.Addr[:4]) // Copy the ifr to userspace. 
dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) ifc.Len += int32(linux.SizeOfIFReq) - if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil { + if _, err := ifr.CopyOut(t, hostarch.Addr(dst)); err != nil { return err } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index fc29f8f13..30f3ad153 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -17,6 +17,7 @@ package netstack import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -197,7 +198,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -245,7 +246,7 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by } s.readMu.Lock() defer s.readMu.Unlock() - s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0 + s.sockOptTimestamp = hostarch.ByteOrder.Uint32(optVal) != 0 return nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { @@ -254,7 +255,7 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by } s.readMu.Lock() defer s.readMu.Unlock() - s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0 + s.sockOptInq = hostarch.ByteOrder.Uint32(optVal) != 0 return nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 909341dcf..4c3d48096 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -216,7 +217,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux unix. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux unix. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error @@ -356,7 +357,7 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent { Type: fs.Socket, DeviceID: d.DeviceID(), InodeID: ino, - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, }) // Dirent name matches net/socket.c:sockfs_dname. 
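The hunks above mechanically swap usermem.ByteOrder and usermem.Addr for hostarch.ByteOrder and hostarch.Addr when decoding socket-option values and user pointers. A minimal standalone sketch of the same decoding pattern, with a hypothetical Addr stand-in and binary.LittleEndian playing the role of the native byte order on the hosts gVisor targets (this is an illustration, not the hostarch API):

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// Addr is a local stand-in for hostarch.Addr: an address in the
// application's virtual address space.
type Addr uintptr

// byteOrder stands in for hostarch.ByteOrder; x86-64 and arm64 Linux
// hosts are little-endian.
var byteOrder binary.ByteOrder = binary.LittleEndian

const sizeOfInt32 = 4

// parseIntOrChar mirrors the option-value pattern in the diff: a value
// may arrive as a full 32-bit integer or as a single byte.
func parseIntOrChar(buf []byte) (int32, error) {
	if len(buf) == 0 {
		return 0, errors.New("invalid argument")
	}
	if len(buf) >= sizeOfInt32 {
		return int32(byteOrder.Uint32(buf)), nil
	}
	return int32(buf[0]), nil
}

func main() {
	full := make([]byte, 4)
	byteOrder.PutUint32(full, 1)
	v, _ := parseIntOrChar(full)
	c, _ := parseIntOrChar([]byte{3})
	fmt.Println(v, c)                      // 1 3
	fmt.Printf("%#x\n", uintptr(Addr(0x7f1234))) // addresses print as plain integers
}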
@@ -571,19 +572,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - binary.Unmarshal(data[:unix.SizeofSockaddrInet4], usermem.ByteOrder, &addr) + binary.Unmarshal(data[:unix.SizeofSockaddrInet4], hostarch.ByteOrder, &addr) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - binary.Unmarshal(data[:unix.SizeofSockaddrInet6], usermem.ByteOrder, &addr) + binary.Unmarshal(data[:unix.SizeofSockaddrInet6], hostarch.ByteOrder, &addr) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - binary.Unmarshal(data[:unix.SizeofSockaddrUnix], usermem.ByteOrder, &addr) + binary.Unmarshal(data[:unix.SizeofSockaddrUnix], hostarch.ByteOrder, &addr) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], usermem.ByteOrder, &addr) + binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], hostarch.ByteOrder, &addr) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -693,7 +694,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { } // Get the rest of the fields based on the address family. - switch family := usermem.ByteOrder.Uint16(addr); family { + switch family := hostarch.ByteOrder.Uint16(addr); family { case linux.AF_UNIX: path := addr[2:] if len(path) > linux.UnixPathMax { @@ -715,7 +716,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) + binary.Unmarshal(addr[:sockAddrInetSize], hostarch.ByteOrder, &a) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -728,7 +729,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) + binary.Unmarshal(addr[:sockAddrInet6Size], hostarch.ByteOrder, &a) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -744,7 +745,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + binary.Unmarshal(addr[:sockAddrLinkSize], hostarch.ByteOrder, &a) if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index ff53a26b7..c9cbefb3a 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -40,6 +40,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/hostarch", "//pkg/log", "//pkg/marshal", "//pkg/refs", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 159b8f90f..408dfb08d 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -130,7 +130,8 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv } ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) 
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } @@ -175,8 +176,9 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */) + ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) return ep } @@ -299,8 +301,9 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } - ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits) + ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) + ne.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize} readQueue.InitRefs() @@ -366,6 +369,7 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint // to reflect this endpoint's send buffer size. if bufSz := e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()); bufSz != e.ops.GetSendBufferSize() { e.ops.SetSendBufferSize(bufSz, false /* notify */) + e.ops.SetReceiveBufferSize(bufSz, false /* notify */) } } diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go index 590b0bd01..b20334d4f 100644 --- a/pkg/sentry/socket/unix/transport/connectioned_state.go +++ b/pkg/sentry/socket/unix/transport/connectioned_state.go @@ -54,5 +54,5 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd // afterLoad is invoked by stateify. func (e *connectionedEndpoint) afterLoad() { - e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) } diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index d0df28b59..61338728a 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -45,7 +45,8 @@ func NewConnectionless(ctx context.Context) Endpoint { q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go index 2ef337ec8..1bb71baf7 100644 --- a/pkg/sentry/socket/unix/transport/connectionless_state.go +++ b/pkg/sentry/socket/unix/transport/connectionless_state.go @@ -16,5 +16,5 @@ package transport // afterLoad is invoked by stateify. 
func (e *connectionlessEndpoint) afterLoad() { - e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0c5f5ab42..837ab4fde 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -868,11 +868,7 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } + log.Warningf("Unsupported socket option: %d", opt) return nil } @@ -905,19 +901,6 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } return int(v), nil - case tcpip.ReceiveBufferSizeOption: - e.Lock() - if e.receiver == nil { - e.Unlock() - return -1, &tcpip.ErrNotConnected{} - } - v := e.receiver.RecvMaxQueueSize() - e.Unlock() - if v < 0 { - return -1, &tcpip.ErrQueueSizeNotSupported{} - } - return int(v), nil - default: log.Warningf("Unsupported socket option: %d", opt) return -1, &tcpip.ErrUnknownProtocolOption{} @@ -1029,3 +1012,15 @@ func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption { Max: maxBufferSize, } } + +// getReceiveBufferLimits implements tcpip.GetReceiveBufferLimits. +// +// We define min, max and default values for unix socket implementation. Unix +// sockets do not use receive buffer. +func getReceiveBufferLimits(tcpip.StackHandler) tcpip.ReceiveBufferSizeOption { + return tcpip.ReceiveBufferSizeOption{ + Min: minimumBufferSize, + Default: defaultBufferSize, + Max: maxBufferSize, + } +} diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index b22f7973a..db7b1affe 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -192,7 +193,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 7890d1048..c39e317ff 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" @@ -112,7 +113,7 @@ func (s *SocketVFS2) Release(ctx context.Context) { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. 
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen) } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 1b7fd2232..2ebd77f82 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/binary", "//pkg/bits", "//pkg/eventchannel", + "//pkg/hostarch", "//pkg/marshal/primitive", "//pkg/seccomp", "//pkg/sentry/arch", @@ -35,7 +36,6 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/netlink", "//pkg/sentry/syscalls/linux", - "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go index ae3b998c8..48650e3f9 100644 --- a/pkg/sentry/strace/epoll.go +++ b/pkg/sentry/strace/epoll.go @@ -21,10 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) -func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string { +func epollEvent(t *kernel.Task, eventAddr hostarch.Addr) string { var e linux.EpollEvent if _, err := e.CopyIn(t, eventAddr); err != nil { return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err) @@ -35,7 +36,7 @@ func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string { return sb.String() } -func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string { +func epollEvents(t *kernel.Task, eventsAddr hostarch.Addr, numEvents, maxBytes uint64) string { var sb strings.Builder fmt.Fprintf(&sb, "%#x {", eventsAddr) addr := eventsAddr diff --git a/pkg/sentry/strace/poll.go b/pkg/sentry/strace/poll.go index 074e80f9b..572a8b50b 100644 --- a/pkg/sentry/strace/poll.go +++ b/pkg/sentry/strace/poll.go @@ -22,7 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // PollEventSet is the set of poll(2) event flags. 
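The unix transport hunks above (connectioned.go, connectionless.go, unix.go) add receive-buffer plumbing: endpoints now set a default receive buffer size and pass a getReceiveBufferLimits callback into InitHandler alongside the existing send-buffer getter. A hedged, standalone sketch of how such a Min/Default/Max limits struct is typically applied to a requested size; the struct shape mirrors tcpip.ReceiveBufferSizeOption, but the constants and clamp helper here are illustrative, not gVisor's implementation:

package main

import "fmt"

// receiveBufferLimits mirrors the shape of tcpip.ReceiveBufferSizeOption.
type receiveBufferLimits struct {
	Min, Default, Max int64
}

// Illustrative values; the real minimumBufferSize/defaultBufferSize/
// maxBufferSize constants live in the unix transport package.
var unixLimits = receiveBufferLimits{
	Min:     4 << 10,   // 4 KiB
	Default: 208 << 10, // 208 KiB
	Max:     4 << 20,   // 4 MiB
}

// clamp applies the limits to a size requested via an SO_RCVBUF-style call.
func clamp(req int64, l receiveBufferLimits) int64 {
	if req < l.Min {
		return l.Min
	}
	if req > l.Max {
		return l.Max
	}
	return req
}

func main() {
	fmt.Println(clamp(1, unixLimits))      // clamped up to Min
	fmt.Println(clamp(64<<20, unixLimits)) // clamped down to Max
	fmt.Println(clamp(unixLimits.Default, unixLimits))
}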
@@ -52,7 +53,7 @@ func pollFD(t *kernel.Task, pfd *linux.PollFD, post bool) string { return fmt.Sprintf("{FD: %s, Events: %s, REvents: %s}", fd(t, pfd.FD), PollEventSet.Parse(uint64(pfd.Events)), revents) } -func pollFDs(t *kernel.Task, addr usermem.Addr, nfds uint, post bool) string { +func pollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint, post bool) string { if addr == 0 { return "null" } diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go index 3a4c32aa0..e6e928157 100644 --- a/pkg/sentry/strace/select.go +++ b/pkg/sentry/strace/select.go @@ -19,7 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) func fdsFromSet(t *kernel.Task, set []byte) []int { @@ -35,7 +36,7 @@ func fdsFromSet(t *kernel.Task, set []byte) []int { return fds } -func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string { +func fdSet(t *kernel.Task, nfds int, addr hostarch.Addr) string { if nfds < 0 { return fmt.Sprintf("%#x (negative nfds)", addr) } diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go index c41f36e3f..e5b379a20 100644 --- a/pkg/sentry/strace/signal.go +++ b/pkg/sentry/strace/signal.go @@ -21,7 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // signalNames contains the names of all named signals. @@ -100,7 +101,7 @@ var sigActionFlags = abi.FlagSet{ }, } -func sigSet(t *kernel.Task, addr usermem.Addr) string { +func sigSet(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -110,7 +111,7 @@ func sigSet(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x (error copying sigset: %v)", addr, err) } - set := linux.SignalSet(usermem.ByteOrder.Uint64(b[:])) + set := linux.SignalSet(hostarch.ByteOrder.Uint64(b[:])) return fmt.Sprintf("%#x %s", addr, formatSigSet(set)) } @@ -124,7 +125,7 @@ func formatSigSet(set linux.SignalSet) string { return fmt.Sprintf("[%v]", strings.Join(signals, " ")) } -func sigAction(t *kernel.Task, addr usermem.Addr) string { +func sigAction(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index d943a7cb1..e5b7f9b96 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -26,7 +26,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // SocketFamily are the possible socket(2) families. 
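strace's sigSet helper above copies eight bytes from the tracee and interprets them as a signal bit mask via hostarch.ByteOrder before formatting the set bits. A small standalone sketch of that decoding step; the signal-name formatting and the little-endian assumption are illustrative, the real helper uses the signalNames table:

package main

import (
	"encoding/binary"
	"fmt"
	"strings"
)

// formatSigSet renders a 64-bit signal mask one entry per set bit;
// in a Linux sigset, bit i corresponds to signal number i+1.
func formatSigSet(mask uint64) string {
	var sigs []string
	for i := 0; i < 64; i++ {
		if mask&(1<<uint(i)) != 0 {
			sigs = append(sigs, fmt.Sprintf("signal %d", i+1))
		}
	}
	return "[" + strings.Join(sigs, " ") + "]"
}

func main() {
	// Eight bytes as they would be copied in from the tracee; assuming a
	// little-endian host, which is what hostarch.ByteOrder resolves to on
	// x86-64 and arm64.
	b := make([]byte, 8)
	binary.LittleEndian.PutUint64(b, 1<<(2-1)|1<<(9-1)) // SIGINT(2) and SIGKILL(9)
	fmt.Println(formatSigSet(binary.LittleEndian.Uint64(b)))
}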
@@ -161,7 +162,7 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } -func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) string { +func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) string { if length > maxBytes { return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length) } @@ -180,7 +181,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h) + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) var skipData bool level := "SOL_SOCKET" @@ -230,7 +231,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) numRights := rightsSize / linux.SizeOfControlMessageRight fds := make(linux.ControlMessageRights, numRights) - binary.Unmarshal(buf[i:i+rightsSize], usermem.ByteOrder, &fds) + binary.Unmarshal(buf[i:i+rightsSize], hostarch.ByteOrder, &fds) rights := make([]string, 0, len(fds)) for _, fd := range fds { @@ -257,7 +258,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", @@ -281,7 +282,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) } var tv linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &tv) + binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &tv) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", @@ -301,7 +302,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) } -func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string { +func msghdr(t *kernel.Task, addr hostarch.Addr, printContent bool, maxBytes uint64) string { var msg slinux.MessageHeader64 if _, err := msg.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err) @@ -311,17 +312,17 @@ func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint6 addr, msg.Name, msg.NameLen, - iovecs(t, usermem.Addr(msg.Iov), int(msg.IovLen), printContent, maxBytes), + iovecs(t, hostarch.Addr(msg.Iov), int(msg.IovLen), printContent, maxBytes), ) if printContent { - s = fmt.Sprintf("%s, control={%s}", s, cmsghdr(t, usermem.Addr(msg.Control), msg.ControlLen, maxBytes)) + s = fmt.Sprintf("%s, control={%s}", s, cmsghdr(t, hostarch.Addr(msg.Control), msg.ControlLen, maxBytes)) } else { s = fmt.Sprintf("%s, control=%#x, control_len=%d", s, msg.Control, msg.ControlLen) } return fmt.Sprintf("%s, flags=%d}", s, msg.Flags) } -func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { +func sockAddr(t *kernel.Task, addr hostarch.Addr, length uint32) string { if addr == 0 { return "null" } @@ -335,7 +336,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { if len(b) < 2 { return fmt.Sprintf("%#x {address too short: %d bytes}", addr, len(b)) } - family := usermem.ByteOrder.Uint16(b) + family := hostarch.ByteOrder.Uint16(b) familyStr := 
SocketFamily.Parse(uint64(family)) @@ -362,7 +363,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { } } -func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) string { +func postSockAddr(t *kernel.Task, addr hostarch.Addr, lengthPtr hostarch.Addr) string { if addr == 0 { return "null" } @@ -379,14 +380,14 @@ func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) str return sockAddr(t, addr, l) } -func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) { +func copySockLen(t *kernel.Task, addr hostarch.Addr) (uint32, error) { // socklen_t is 32-bits. var l primitive.Uint32 _, err := l.CopyIn(t, addr) return uint32(l), err } -func sockLenPointer(t *kernel.Task, addr usermem.Addr) string { +func sockLenPointer(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -420,7 +421,7 @@ func sockFlags(flags int32) string { return SocketFlagSet.Parse(uint64(flags)) } -func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen usermem.Addr, maximumBlobSize uint, rval uintptr) string { +func getSockOptVal(t *kernel.Task, level, optname uint64, optVal hostarch.Addr, optLen hostarch.Addr, maximumBlobSize uint, rval uintptr) string { if int(rval) < 0 { return hexNum(uint64(optVal)) } @@ -434,7 +435,7 @@ func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, o return sockOptVal(t, level, optname, optVal, uint64(l), maximumBlobSize) } -func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string { +func sockOptVal(t *kernel.Task, level, optname uint64, optVal hostarch.Addr, optLen uint64, maximumBlobSize uint) string { switch optLen { case 1: var v primitive.Uint8 diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 396744597..ec5d5f846 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -32,7 +32,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" pb "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // DefaultLogMaximumSize is the default LogMaximumSize. 
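UnmarshalSockAddr and strace's sockAddr helper above both begin by reading the 16-bit address family from the first two bytes of the buffer in native byte order, and only then decide how to decode the remainder. A standalone sketch of that first dispatch step; the family constants are the Linux values written out locally so the sketch builds anywhere, and only the family check is shown:

package main

import (
	"encoding/binary"
	"fmt"
)

// Linux address-family values, defined locally for the sketch.
const (
	afUnix  = 1
	afInet  = 2
	afInet6 = 10
)

// familyOf mirrors the first step of sockAddr/UnmarshalSockAddr: the
// family occupies the first two bytes of the sockaddr, in native order
// (little-endian on the hosts gVisor supports).
func familyOf(b []byte) (uint16, error) {
	if len(b) < 2 {
		return 0, fmt.Errorf("address too short: %d bytes", len(b))
	}
	return binary.LittleEndian.Uint16(b), nil
}

func main() {
	// A sockaddr_in: family AF_INET, port 80 (network order), 127.0.0.1.
	sa := []byte{afInet, 0, 0, 80, 127, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}
	fam, err := familyOf(sa)
	if err != nil {
		panic(err)
	}
	switch fam {
	case afUnix:
		fmt.Println("AF_UNIX")
	case afInet:
		fmt.Println("AF_INET")
	case afInet6:
		fmt.Println("AF_INET6")
	default:
		fmt.Println("unknown family", fam)
	}
}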
@@ -62,7 +63,7 @@ func hexArg(arg arch.SyscallArgument) string { return hexNum(arg.Uint64()) } -func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, maxBytes uint64) string { +func iovecs(t *kernel.Task, addr hostarch.Addr, iovcnt int, printContent bool, maxBytes uint64) string { if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV { return fmt.Sprintf("%#x (error decoding iovecs: invalid iovcnt)", addr) } @@ -107,7 +108,7 @@ func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, ma return fmt.Sprintf("%#x %s", addr, strings.Join(iovs, ", ")) } -func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) string { +func dump(t *kernel.Task, addr hostarch.Addr, size uint, maximumBlobSize uint) string { origSize := size if size > maximumBlobSize { size = maximumBlobSize @@ -131,7 +132,7 @@ func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) st return fmt.Sprintf("%#x %q%s", addr, b[:amt], dot) } -func path(t *kernel.Task, addr usermem.Addr) string { +func path(t *kernel.Task, addr hostarch.Addr) string { path, err := t.CopyInString(addr, linux.PATH_MAX) if err != nil { return fmt.Sprintf("%#x (error decoding path: %s)", addr, err) @@ -196,7 +197,7 @@ func fdVFS2(t *kernel.Task, fd int32) string { return fmt.Sprintf("%#x %s", fd, name) } -func fdpair(t *kernel.Task, addr usermem.Addr) string { +func fdpair(t *kernel.Task, addr hostarch.Addr) string { var fds [2]int32 _, err := primitive.CopyInt32SliceIn(t, addr, fds[:]) if err != nil { @@ -206,7 +207,7 @@ func fdpair(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x [%d %d]", addr, fds[0], fds[1]) } -func uname(t *kernel.Task, addr usermem.Addr) string { +func uname(t *kernel.Task, addr hostarch.Addr) string { var u linux.UtsName if _, err := u.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err) @@ -215,7 +216,7 @@ func uname(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x %s", addr, u) } -func utimensTimespec(t *kernel.Task, addr usermem.Addr) string { +func utimensTimespec(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -237,7 +238,7 @@ func utimensTimespec(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x {sec=%v nsec=%s}", addr, tim.Sec, ns) } -func timespec(t *kernel.Task, addr usermem.Addr) string { +func timespec(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -249,7 +250,7 @@ func timespec(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec) } -func timeval(t *kernel.Task, addr usermem.Addr) string { +func timeval(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -262,7 +263,7 @@ func timeval(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x {sec=%v usec=%v}", addr, tim.Sec, tim.Usec) } -func utimbuf(t *kernel.Task, addr usermem.Addr) string { +func utimbuf(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -275,7 +276,7 @@ func utimbuf(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x {actime=%v, modtime=%v}", addr, utim.Actime, utim.Modtime) } -func stat(t *kernel.Task, addr usermem.Addr) string { +func stat(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -287,27 +288,27 @@ func stat(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, 
size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec)) } -func itimerval(t *kernel.Task, addr usermem.Addr) string { +func itimerval(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } interval := timeval(t, addr) - value := timeval(t, addr+usermem.Addr((*linux.Timeval)(nil).SizeBytes())) + value := timeval(t, addr+hostarch.Addr((*linux.Timeval)(nil).SizeBytes())) return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value) } -func itimerspec(t *kernel.Task, addr usermem.Addr) string { +func itimerspec(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } interval := timespec(t, addr) - value := timespec(t, addr+usermem.Addr((*linux.Timespec)(nil).SizeBytes())) + value := timespec(t, addr+hostarch.Addr((*linux.Timespec)(nil).SizeBytes())) return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value) } -func stringVector(t *kernel.Task, addr usermem.Addr) string { +func stringVector(t *kernel.Task, addr hostarch.Addr) string { vec, err := t.CopyInVector(addr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize) if err != nil { return fmt.Sprintf("%#x {error copying vector: %v}", addr, err) @@ -323,7 +324,7 @@ func stringVector(t *kernel.Task, addr usermem.Addr) string { return s } -func rusage(t *kernel.Task, addr usermem.Addr) string { +func rusage(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -335,7 +336,7 @@ func rusage(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x %+v", addr, ru) } -func capHeader(t *kernel.Task, addr usermem.Addr) string { +func capHeader(t *kernel.Task, addr hostarch.Addr) string { if addr == 0 { return "null" } @@ -360,7 +361,7 @@ func capHeader(t *kernel.Task, addr usermem.Addr) string { return fmt.Sprintf("%#x {Version: %s, Pid: %d}", addr, version, hdr.Pid) } -func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { +func capData(t *kernel.Task, hdrAddr, dataAddr hostarch.Addr) string { if dataAddr == 0 { return "null" } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 3dcf36a96..408a6c422 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -64,6 +64,7 @@ go_library( "//pkg/abi/linux", "//pkg/bpf", "//pkg/context", + "//pkg/hostarch", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index efec93f73..37121186a 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -33,6 +33,14 @@ var ( partialResultOnce sync.Once ) +// incrementPartialResultMetric increments PartialResultMetric by calling +// Increment(). This is added as the func Do() which is called below requires +// us to pass a function which does not take any arguments, whereas Increment() +// takes a variadic number of arguments. +func incrementPartialResultMetric() { + partialResultMetric.Increment() +} + // HandleIOErrorVFS2 handles special error cases for partial results. For some // errors, we may consume the error and return only the partial read/write. 
// @@ -48,7 +56,7 @@ func HandleIOErrorVFS2(ctx context.Context, partialResult bool, ioerr, intr erro root := vfs.RootFromContext(ctx) name, _ := fs.PathnameWithDeleted(ctx, root, f.VirtualDentry()) log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, ioerr, ioerr, op, name) - partialResultOnce.Do(partialResultMetric.Increment) + partialResultOnce.Do(incrementPartialResultMetric) } return nil } @@ -66,7 +74,7 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o // An unknown error is encountered with a partial read/write. name, _ := f.Dirent.FullName(nil /* ignore chroot */) log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, ioerr, ioerr, op, name, f.FileOperations) - partialResultOnce.Do(partialResultMetric.Increment) + partialResultOnce.Do(incrementPartialResultMetric) } return nil } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ac53a0c0e..2d2212605 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -18,11 +18,11 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/syscalls" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) const ( @@ -405,7 +405,7 @@ var AMD64 = &kernel.SyscallTable{ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), }, - Emulate: map[usermem.Addr]uintptr{ + Emulate: map[hostarch.Addr]uintptr{ 0xffffffffff600000: 96, // vsyscall gettimeofday(2) 0xffffffffff600400: 201, // vsyscall time(2) 0xffffffffff600800: 309, // vsyscall getcpu(2) @@ -723,7 +723,7 @@ var ARM64 = &kernel.SyscallTable{ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), }, - Emulate: map[usermem.Addr]uintptr{}, + Emulate: map[hostarch.Addr]uintptr{}, Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { t.Kernel().EmitUnimplementedEvent(t) return 0, syserror.ENOSYS diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go index 434559b80..e8c2d8f9e 100644 --- a/pkg/sentry/syscalls/linux/sigset.go +++ b/pkg/sentry/syscalls/linux/sigset.go @@ -16,9 +16,9 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // CopyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and @@ -27,7 +27,7 @@ import ( // TODO(gvisor.dev/issue/1624): This is only exported because // syscalls/vfs2/signal.go depends on it. Once vfs1 is deleted and the vfs2 // syscalls are moved into this package, then they can be unexported. 
-func CopyInSigSet(t *kernel.Task, sigSetAddr usermem.Addr, size uint) (linux.SignalSet, error) { +func CopyInSigSet(t *kernel.Task, sigSetAddr hostarch.Addr, size uint) (linux.SignalSet, error) { if size != linux.SignalSetSize { return 0, syserror.EINVAL } @@ -35,14 +35,14 @@ func CopyInSigSet(t *kernel.Task, sigSetAddr usermem.Addr, size uint) (linux.Sig if _, err := t.CopyInBytes(sigSetAddr, b); err != nil { return 0, err } - mask := usermem.ByteOrder.Uint64(b[:]) + mask := hostarch.ByteOrder.Uint64(b[:]) return linux.SignalSet(mask) &^ kernel.UnblockableSignals, nil } // copyOutSigSet copies out a sigset_t. -func copyOutSigSet(t *kernel.Task, sigSetAddr usermem.Addr, mask linux.SignalSet) error { +func copyOutSigSet(t *kernel.Task, sigSetAddr hostarch.Addr, mask linux.SignalSet) error { b := t.CopyScratchBuffer(8) - usermem.ByteOrder.PutUint64(b, uint64(mask)) + hostarch.ByteOrder.PutUint64(b, uint64(mask)) _, err := t.CopyOutBytes(sigSetAddr, b) return err } @@ -55,15 +55,15 @@ func copyOutSigSet(t *kernel.Task, sigSetAddr usermem.Addr, mask linux.SignalSet // }; // // and returns sigset_addr and size. -func copyInSigSetWithSize(t *kernel.Task, addr usermem.Addr) (usermem.Addr, uint, error) { +func copyInSigSetWithSize(t *kernel.Task, addr hostarch.Addr) (hostarch.Addr, uint, error) { switch t.Arch().Width() { case 8: in := t.CopyScratchBuffer(16) if _, err := t.CopyInBytes(addr, in); err != nil { return 0, 0, err } - maskAddr := usermem.Addr(usermem.ByteOrder.Uint64(in[0:])) - maskSize := uint(usermem.ByteOrder.Uint64(in[8:])) + maskAddr := hostarch.Addr(hostarch.ByteOrder.Uint64(in[0:])) + maskSize := uint(hostarch.ByteOrder.Uint64(in[8:])) return maskAddr, maskSize, nil default: return 0, 0, syserror.ENOSYS diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index c2285f796..70e8569a8 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -17,6 +17,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -152,7 +153,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S } // Keep rolling. - eventsAddr += usermem.Addr(linux.IOEventSize) + eventsAddr += hostarch.Addr(linux.IOEventSize) } // Everything finished. @@ -191,12 +192,12 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) // I/O. 
switch cb.OpCode { case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE: - return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{ + return t.SingleIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{ AddressSpaceActive: false, }) case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV: - return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{ + return t.IovecsIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{ AddressSpaceActive: false, }) @@ -219,7 +220,7 @@ func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // LINT.IfChange -func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, actx *mm.AIOContext, eventFile *fs.File) kernel.AIOCallback { +func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr hostarch.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, actx *mm.AIOContext, eventFile *fs.File) kernel.AIOCallback { return func(ctx context.Context) { if actx.Dead() { actx.CancelPendingRequest() @@ -264,7 +265,7 @@ func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linu } // submitCallback processes a single callback. -func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error { +func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr hostarch.Addr) error { file := t.GetFile(cb.FD) if file == nil { // File not found. @@ -339,7 +340,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc for i := int32(0); i < nrEvents; i++ { // Copy in the callback address. - var cbAddr usermem.Addr + var cbAddr hostarch.Addr switch t.Arch().Width() { case 8: var cbAddrP primitive.Uint64 @@ -351,7 +352,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Nothing done. return 0, nil, err } - cbAddr = usermem.Addr(cbAddrP) + cbAddr = hostarch.Addr(cbAddrP) default: return 0, nil, syserror.ENOSYS } @@ -379,7 +380,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } // Advance to the next one. - addr += usermem.Addr(t.Arch().Width()) + addr += hostarch.Addr(t.Arch().Width()) } return uintptr(nrEvents), nil, nil diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index fd9649340..9cd238efd 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -29,7 +30,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // fileOpAt performs an operation on the second last component in the path. @@ -115,7 +115,7 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro } // copyInPath copies a path in. 
-func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string, dirPath bool, err error) { +func copyInPath(t *kernel.Task, addr hostarch.Addr, allowEmpty bool) (path string, dirPath bool, err error) { path, err = t.CopyInString(addr, linux.PATH_MAX) if err != nil { return "", false, err @@ -133,7 +133,7 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string // LINT.IfChange -func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) { +func openAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint) (fd uintptr, err error) { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { return 0, err @@ -208,7 +208,7 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint return fd, err // Use result in frame. } -func mknodAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { +func mknodAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMode) error { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { return err @@ -301,7 +301,7 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, mknodAt(t, dirFD, path, mode) } -func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode linux.FileMode) (fd uintptr, err error) { +func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode linux.FileMode) (fd uintptr, err error) { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { return 0, err @@ -515,7 +515,7 @@ func (ac accessContext) Value(key interface{}) interface{} { } } -func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode uint) error { +func accessAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode uint) error { const rOK = 4 const wOK = 2 const xOK = 1 @@ -694,7 +694,7 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Top it off with a terminator. 
- _, err = t.CopyOutBytes(addr+usermem.Addr(bytes), []byte("\x00")) + _, err = t.CopyOutBytes(addr+hostarch.Addr(bytes), []byte("\x00")) return uintptr(bytes + 1), nil, err } @@ -1164,7 +1164,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // LINT.IfChange -func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { +func mkdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { return err @@ -1216,7 +1216,7 @@ func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, mkdirAt(t, dirFD, addr, mode) } -func rmdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { +func rmdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { return err @@ -1256,7 +1256,7 @@ func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, rmdirAt(t, linux.AT_FDCWD, addr) } -func symlinkAt(t *kernel.Task, dirFD int32, newAddr usermem.Addr, oldAddr usermem.Addr) error { +func symlinkAt(t *kernel.Task, dirFD int32, newAddr hostarch.Addr, oldAddr hostarch.Addr) error { newPath, dirPath, err := copyInPath(t, newAddr, false /* allowEmpty */) if err != nil { return err @@ -1341,7 +1341,7 @@ func mayLinkAt(t *kernel.Task, target *fs.Inode) error { // linkAt creates a hard link to the target specified by oldDirFD and oldAddr, // specified by newDirFD and newAddr. If resolve is true, then the symlinks // will be followed when evaluating the target. -func linkAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr, resolve, allowEmpty bool) error { +func linkAt(t *kernel.Task, oldDirFD int32, oldAddr hostarch.Addr, newDirFD int32, newAddr hostarch.Addr, resolve, allowEmpty bool) error { oldPath, _, err := copyInPath(t, oldAddr, allowEmpty) if err != nil { return err @@ -1448,7 +1448,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // LINT.IfChange -func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) { +func readlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, bufAddr hostarch.Addr, size uint) (copied uintptr, err error) { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { return 0, err @@ -1511,7 +1511,7 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // LINT.IfChange -func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { +func unlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr) error { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { return err @@ -1728,7 +1728,7 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { return nil } -func chownAt(t *kernel.Task, fd int32, addr usermem.Addr, resolve, allowEmpty bool, uid auth.UID, gid auth.GID) error { +func chownAt(t *kernel.Task, fd int32, addr hostarch.Addr, resolve, allowEmpty bool, uid auth.UID, gid auth.GID) error { path, _, err := copyInPath(t, addr, allowEmpty) if err != nil { return err @@ -1815,7 +1815,7 @@ func chmod(t *kernel.Task, d *fs.Dirent, mode linux.FileMode) error { return nil } -func chmodAt(t *kernel.Task, fd int32, addr usermem.Addr, mode linux.FileMode) error { +func chmodAt(t *kernel.Task, fd int32, addr hostarch.Addr, mode 
linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { return err @@ -1866,7 +1866,7 @@ func defaultSetToSystemTimeSpec() fs.TimeSpec { } } -func utimes(t *kernel.Task, dirFD int32, addr usermem.Addr, ts fs.TimeSpec, resolve bool) error { +func utimes(t *kernel.Task, dirFD int32, addr hostarch.Addr, ts fs.TimeSpec, resolve bool) error { setTimestamp := func(root *fs.Dirent, d *fs.Dirent, _ uint) error { // Does the task own the file? if !d.Inode.CheckOwnership(t) { @@ -2030,7 +2030,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // LINT.IfChange -func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error { +func renameAt(t *kernel.Task, oldDirFD int32, oldAddr hostarch.Addr, newDirFD int32, newAddr hostarch.Addr) error { newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */) if err != nil { return err diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index f39ce0639..eeea1613b 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -18,11 +18,11 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // futexWaitRestartBlock encapsulates the state required to restart futex(2) @@ -41,7 +41,7 @@ type futexWaitRestartBlock struct { // Restart implements kernel.SyscallRestartBlock.Restart. func (f *futexWaitRestartBlock) Restart(t *kernel.Task) (uintptr, error) { - return futexWaitDuration(t, f.duration, false, usermem.Addr(f.addr), f.private, f.val, f.mask) + return futexWaitDuration(t, f.duration, false, hostarch.Addr(f.addr), f.private, f.val, f.mask) } // futexWaitAbsolute performs a FUTEX_WAIT_BITSET, blocking until the wait is @@ -51,7 +51,7 @@ func (f *futexWaitRestartBlock) Restart(t *kernel.Task) (uintptr, error) { // // If blocking is interrupted, the syscall is restarted with the original // arguments. -func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, forever bool, addr usermem.Addr, private bool, val, mask uint32) (uintptr, error) { +func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, forever bool, addr hostarch.Addr, private bool, val, mask uint32) (uintptr, error) { w := t.FutexWaiter() err := t.Futex().WaitPrepare(w, t, addr, private, val, mask) if err != nil { @@ -87,7 +87,7 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo // syscall. If forever is true, the syscall is restarted with the original // arguments. If forever is false, duration is a relative timeout and the // syscall is restarted with the remaining timeout. 
-func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, addr usermem.Addr, private bool, val, mask uint32) (uintptr, error) { +func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, addr hostarch.Addr, private bool, val, mask uint32) (uintptr, error) { w := t.FutexWaiter() err := t.Futex().WaitPrepare(w, t, addr, private, val, mask) if err != nil { @@ -124,7 +124,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add return 0, syserror.ERESTART_RESTARTBLOCK } -func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.Addr, private bool) error { +func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr hostarch.Addr, private bool) error { w := t.FutexWaiter() locked, err := t.Futex().LockPI(w, t, addr, uint32(t.ThreadID()), private, false) if err != nil { @@ -152,7 +152,7 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.A return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } -func tryLockPI(t *kernel.Task, addr usermem.Addr, private bool) error { +func tryLockPI(t *kernel.Task, addr hostarch.Addr, private bool) error { w := t.FutexWaiter() locked, err := t.Futex().LockPI(w, t, addr, uint32(t.ThreadID()), private, true) if err != nil { diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index b25f7d881..bbba71d8f 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -19,6 +19,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -62,7 +63,7 @@ func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getdents implements the core of getdents(2)/getdents64(2). // f is the syscall implementation dirent serialization function. -func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dirent, io.Writer) (int, error)) (uintptr, error) { +func getdents(t *kernel.Task, fd int32, addr hostarch.Addr, size int, f func(*dirent, io.Writer) (int, error)) (uintptr, error) { dir := t.GetFile(fd) if dir == nil { return 0, syserror.EBADF diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go index 9b4a5c3f1..6d27f4292 100644 --- a/pkg/sentry/syscalls/linux/sys_mempolicy.go +++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -31,7 +32,7 @@ const ( allowedNodemask = (1 << maxNodes) - 1 ) -func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, error) { +func copyInNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32) (uint64, error) { // "nodemask points to a bit mask of node IDs that contains up to maxnode // bits. The bit mask size is rounded to the next multiple of // sizeof(unsigned long), but the kernel will use bits only up to maxnode. @@ -41,7 +42,7 @@ func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, // because of what appears to be a bug: mm/mempolicy.c:get_nodes() uses // maxnode-1, not maxnode, as the number of bits. 
bits := maxnode - 1 - if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0 + if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0 return 0, syserror.EINVAL } if bits == 0 { @@ -53,7 +54,7 @@ func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, if _, err := t.CopyInBytes(addr, buf); err != nil { return 0, err } - val := usermem.ByteOrder.Uint64(buf) + val := hostarch.ByteOrder.Uint64(buf) // Check that only allowed bits in the first unsigned long in the nodemask // are set. if val&^allowedNodemask != 0 { @@ -68,11 +69,11 @@ func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, return val, nil } -func copyOutNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32, val uint64) error { +func copyOutNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32, val uint64) error { // mm/mempolicy.c:copy_nodes_to_user() also uses maxnode-1 as the number of // bits. bits := maxnode - 1 - if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0 + if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0 return syserror.EINVAL } if bits == 0 { @@ -80,7 +81,7 @@ func copyOutNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32, val uint } // Copy out the first unsigned long in the nodemask. buf := t.CopyScratchBuffer(8) - usermem.ByteOrder.PutUint64(buf, val) + hostarch.ByteOrder.PutUint64(buf, val) if _, err := t.CopyOutBytes(addr, buf); err != nil { return err } @@ -258,7 +259,7 @@ func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err } -func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask usermem.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) { +func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask hostarch.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) { flags := linux.NumaPolicy(modeWithFlags & linux.MPOL_MODE_FLAGS) mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS) if flags == linux.MPOL_MODE_FLAGS { diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index cd8dfdfa4..70da0707d 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -23,7 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Brk implements linux syscall brk(2). 
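copyInNodemask above documents a subtle Linux behavior it mirrors: mm/mempolicy.c:get_nodes() uses maxnode-1 as the bit count, rejects masks wider than one page of bits, and allows only the supported node bits in the first unsigned long. A standalone sketch of those checks on an already-copied buffer; maxNodes, the page size, and the error value are illustrative, and the real code additionally verifies that any remaining bytes are zero:

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

const (
	maxNodes        = 8 // illustrative node-count limit
	allowedNodemask = (1 << maxNodes) - 1
	pageSize        = 4096
)

var errInval = errors.New("EINVAL")

// checkNodemask mirrors copyInNodemask: look at maxnode-1 bits, reject
// masks wider than a page, and reject any bit above the supported node
// count in the first unsigned long.
func checkNodemask(buf []byte, maxnode uint32) (uint64, error) {
	bits := maxnode - 1
	if bits > pageSize*8 { // also handles overflow from maxnode == 0
		return 0, errInval
	}
	if bits == 0 {
		return 0, nil
	}
	if len(buf) < 8 {
		return 0, errInval
	}
	val := binary.LittleEndian.Uint64(buf) // native order on supported hosts
	if val&^allowedNodemask != 0 {
		return 0, errInval
	}
	return val, nil
}

func main() {
	buf := make([]byte, 8)
	binary.LittleEndian.PutUint64(buf, 0b0011) // nodes 0 and 1: accepted
	fmt.Println(checkNodemask(buf, 64))
	binary.LittleEndian.PutUint64(buf, 1<<maxNodes) // node 8: rejected
	fmt.Println(checkNodemask(buf, 64))
}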
@@ -61,12 +62,12 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Unmap: fixed, Map32Bit: map32bit, Private: private, - Perms: usermem.AccessType{ + Perms: hostarch.AccessType{ Read: linux.PROT_READ&prot != 0, Write: linux.PROT_WRITE&prot != 0, Execute: linux.PROT_EXEC&prot != 0, }, - MaxPerms: usermem.AnyAccess, + MaxPerms: hostarch.AnyAccess, GrowsDown: linux.MAP_GROWSDOWN&flags != 0, Precommit: linux.MAP_POPULATE&flags != 0, } @@ -160,7 +161,7 @@ func Mremap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal func Mprotect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { length := args[1].Uint64() prot := args[2].Int() - err := t.MemoryManager().MProtect(args[0].Pointer(), length, usermem.AccessType{ + err := t.MemoryManager().MProtect(args[0].Pointer(), length, hostarch.AccessType{ Read: linux.PROT_READ&prot != 0, Write: linux.PROT_WRITE&prot != 0, Execute: linux.PROT_EXEC&prot != 0, @@ -183,7 +184,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, nil } // Not explicitly stated: length need not be page-aligned. - lenAddr, ok := usermem.Addr(length).RoundUp() + lenAddr, ok := hostarch.Addr(length).RoundUp() if !ok { return 0, nil, syserror.EINVAL } @@ -232,7 +233,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // "The length argument need not be a multiple of the page size, but since // residency information is returned for whole pages, length is effectively // rounded up to the next multiple of the page size." - mincore(2) - la, ok := usermem.Addr(length).RoundUp() + la, ok := hostarch.Addr(length).RoundUp() if !ok { return 0, nil, syserror.ENOMEM } @@ -247,7 +248,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if mapped != uint64(la) { return 0, nil, syserror.ENOMEM } - resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize)) + resident := bytes.Repeat([]byte{1}, int(mapped/hostarch.PageSize)) _, err := t.CopyOutBytes(vec, resident) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go index bd0633564..864d2138c 100644 --- a/pkg/sentry/syscalls/linux/sys_mount.go +++ b/pkg/sentry/syscalls/linux/sys_mount.go @@ -20,7 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Mount implements Linux syscall mount(2). @@ -31,7 +32,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall flags := args[3].Uint64() dataAddr := args[4].Pointer() - fsType, err := t.CopyInString(typeAddr, usermem.PageSize) + fsType, err := t.CopyInString(typeAddr, hostarch.PageSize) if err != nil { return 0, nil, err } @@ -52,7 +53,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // character placement, and the address is passed to each file system. // Most file systems always treat this data as a string, though, and so // do all of the ones we implement. 
- data, err = t.CopyInString(dataAddr, usermem.PageSize) + data, err = t.CopyInString(dataAddr, hostarch.PageSize) if err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index f7135ea46..d95034347 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -16,19 +16,19 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange // pipe2 implements the actual system call with flags. -func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { +func pipe2(t *kernel.Task, addr hostarch.Addr, flags uint) (uintptr, error) { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return 0, syserror.EINVAL } diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index 254f4c9f9..da548a14a 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -18,13 +18,13 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -155,7 +155,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. } // CopyInPollFDs copies an array of struct pollfd unless nfds exceeds the max. -func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) { +func CopyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) { if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) { return nil, syserror.EINVAL } @@ -170,7 +170,7 @@ func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD return pfd, nil } -func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) { +func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) { pfd, err := CopyInPollFDs(t, addr, nfds) if err != nil { return timeout, 0, err @@ -198,7 +198,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) } // CopyInFDSet copies an fd set from select(2)/pselect(2). 
-func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { +func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { set := make([]byte, nBytes) if addr != 0 { @@ -215,7 +215,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy return set, nil } -func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) { +func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) { if nfds < 0 || nfds > fileCap { return 0, syserror.EINVAL } @@ -365,7 +365,7 @@ func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) // copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr. // // startNs must be from CLOCK_MONOTONIC. -func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error { +func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr hostarch.Addr) error { if timeout <= 0 { return nil } @@ -377,7 +377,7 @@ func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.D // copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr. // // startNs must be from CLOCK_MONOTONIC. -func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error { +func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr hostarch.Addr) error { if timeout <= 0 { return nil } @@ -391,7 +391,7 @@ func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Du // // +stateify savable type pollRestartBlock struct { - pfdAddr usermem.Addr + pfdAddr hostarch.Addr nfds uint timeout time.Duration } @@ -401,7 +401,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) { return poll(t, p.pfdAddr, p.nfds, p.timeout) } -func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) { +func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) { remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout) // On an interrupt poll(2) is restarted with the remaining timeout. 
if err == syserror.EINTR { diff --git a/pkg/sentry/syscalls/linux/sys_random.go b/pkg/sentry/syscalls/linux/sys_random.go index c0aa0fd60..ae545f80f 100644 --- a/pkg/sentry/syscalls/linux/sys_random.go +++ b/pkg/sentry/syscalls/linux/sys_random.go @@ -24,6 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -64,7 +66,7 @@ func GetRandom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if min > 256 { min = 256 } - n, err := t.MemoryManager().CopyOutFrom(t, usermem.AddrRangeSeqOf(ar), safemem.FromIOReader{&randReader{-1, min}}, usermem.IOOpts{ + n, err := t.MemoryManager().CopyOutFrom(t, hostarch.AddrRangeSeqOf(ar), safemem.FromIOReader{&randReader{-1, min}}, usermem.IOOpts{ AddressSpaceActive: true, }) if n >= int64(min) { diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index 88cd234d1..e64246d57 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -16,12 +16,12 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // rlimit describes an implementation of 'struct rlimit', which may vary from @@ -67,12 +67,12 @@ func (r *rlimit64) fromLimit(lim limits.Limit) { } } -func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error { +func (r *rlimit64) copyIn(t *kernel.Task, addr hostarch.Addr) error { _, err := r.CopyIn(t, addr) return err } -func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error { +func (r *rlimit64) copyOut(t *kernel.Task, addr hostarch.Addr) error { _, err := r.CopyOut(t, addr) return err } diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go index 4fdb4463c..e16d6ff3f 100644 --- a/pkg/sentry/syscalls/linux/sys_seccomp.go +++ b/pkg/sentry/syscalls/linux/sys_seccomp.go @@ -17,10 +17,10 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // userSockFprog is equivalent to Linux's struct sock_fprog on amd64. @@ -33,14 +33,14 @@ type userSockFprog struct { _ [6]byte // padding for alignment // Filter is a user pointer to the struct sock_filter array that makes up - // the filter program. Filter is a uint64 rather than a usermem.Addr - // because usermem.Addr is actually uintptr, which is not a fixed-size + // the filter program. Filter is a uint64 rather than a hostarch.Addr + // because hostarch.Addr is actually uintptr, which is not a fixed-size // type. Filter uint64 } // seccomp applies a seccomp policy to the current task. -func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error { +func seccomp(t *kernel.Task, mode, flags uint64, addr hostarch.Addr) error { // We only support SECCOMP_SET_MODE_FILTER at the moment. if mode != linux.SECCOMP_SET_MODE_FILTER { // Unsupported mode. 
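Note: the userSockFprog comment above explains why Filter stays a uint64 rather than a hostarch.Addr: hostarch.Addr is a uintptr, whose width is platform-dependent, while the marshalled ABI struct needs a fixed layout. A hedged sketch of the conversion at the point of use (the standalone helper is illustrative; the call is the same one the seccomp hunk makes):

package sketch

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/kernel"
)

// copyInFilter keeps the user pointer as a fixed-size uint64 in the
// marshalled struct and converts it to hostarch.Addr only when the
// instructions are actually copied in.
func copyInFilter(t *kernel.Task, filterPtr uint64, length uint16) ([]linux.BPFInstruction, error) {
	filter := make([]linux.BPFInstruction, int(length))
	if _, err := linux.CopyBPFInstructionSliceIn(t, hostarch.Addr(filterPtr), filter); err != nil {
		return nil, err
	}
	return filter, nil
}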
@@ -60,7 +60,7 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error { return err } filter := make([]linux.BPFInstruction, int(fprog.Len)) - if _, err := linux.CopyBPFInstructionSliceIn(t, usermem.Addr(fprog.Filter), filter); err != nil { + if _, err := linux.CopyBPFInstructionSliceIn(t, hostarch.Addr(fprog.Filter), filter); err != nil { return err } compiledFilter, err := bpf.Compile(filter) diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index f0570d927..c84260080 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -19,13 +19,13 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) const opsMax = 500 // SEMOPM @@ -310,7 +310,7 @@ func setVal(t *kernel.Task, id int32, num int32, val int16) error { return set.SetVal(t, num, val, creds, int32(pid)) } -func setValAll(t *kernel.Task, id int32, array usermem.Addr) error { +func setValAll(t *kernel.Task, id int32, array hostarch.Addr) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { @@ -335,7 +335,7 @@ func getVal(t *kernel.Task, id int32, num int32) (int16, error) { return set.GetVal(num, creds) } -func getValAll(t *kernel.Task, id int32, array usermem.Addr) error { +func getValAll(t *kernel.Task, id int32, array hostarch.Addr) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index d639c9bf7..53b12dc41 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -19,12 +19,12 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // "For a process to have permission to send a signal it must @@ -516,7 +516,7 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } // sharedSignalfd is shared between the two calls. -func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) { +func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) { // Copy in the signal mask. mask, err := CopyInSigSet(t, sigset, sigsetsize) if err != nil { diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index c6adfe06b..eff251cec 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -34,12 +35,6 @@ import ( // LINT.IfChange -// minListenBacklog is the minimum reasonable backlog for listening sockets. -const minListenBacklog = 8 - -// maxListenBacklog is the maximum allowed backlog for listening sockets. 
-const maxListenBacklog = 1024 - // maxAddrLen is the maximum socket address length we're willing to accept. const maxAddrLen = 200 @@ -51,6 +46,9 @@ const maxOptLen = 1024 * 8 // buffers upto INT_MAX. const maxControlLen = 10 * 1024 * 1024 +// maxListenBacklog is the maximum limit of listen backlog supported. +const maxListenBacklog = 1024 + // nameLenOffset is the offset from the start of the MessageHeader64 struct to // the NameLen field. const nameLenOffset = 8 @@ -117,7 +115,7 @@ type multipleMessageHeader64 struct { // CaptureAddress allocates memory for and copies a socket address structure // from the untrusted address space range. -func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { +func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) { if addrlen > maxAddrLen { return nil, syserror.EINVAL } @@ -133,7 +131,7 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, // writeAddress writes a sockaddr structure and its length to an output buffer // in the unstrusted address space range. If the address is bigger than the // buffer, it is truncated. -func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { +func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr hostarch.Addr, addrLenPtr hostarch.Addr) error { // Get the buffer length. var bufLen uint32 if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil { @@ -276,7 +274,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // accept is the implementation of the accept syscall. It is called by accept // and accept4 syscall handlers. -func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) { +func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) { // Check that no unsupported flags are passed in. if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { return 0, syserror.EINVAL @@ -366,7 +364,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Listen implements the linux syscall listen(2). func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() - backlog := args[1].Int() + backlog := args[1].Uint() // Get socket from the file descriptor. file := t.GetFile(fd) @@ -381,11 +379,13 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ENOTSOCK } - // Per Linux, the backlog is silently capped to reasonable values. - if backlog <= 0 { - backlog = minListenBacklog - } if backlog > maxListenBacklog { + // Linux treats incoming backlog as uint with a limit defined by + // sysctl_somaxconn. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666 + // + // We use the backlog to allocate a channel of that size, hence enforce + // a hard limit for the backlog. backlog = maxListenBacklog } @@ -472,7 +472,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. 
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr hostarch.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -735,7 +735,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return uintptr(count), nil, nil } -func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { +func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { // Capture the message header and io vectors. var msg MessageHeader64 if _, err := msg.CopyIn(t, msgPtr); err != nil { @@ -745,7 +745,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if msg.IovLen > linux.UIO_MAXIOV { return 0, syserror.EMSGSIZE } - dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ + dst, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ AddressSpaceActive: true, }) if err != nil { @@ -796,7 +796,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i // Copy the address to the caller. if msg.NameLen != 0 { - if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil { + if err := writeAddress(t, sender, senderLen, hostarch.Addr(msg.Name), hostarch.Addr(msgPtr+nameLenOffset)); err != nil { return 0, err } } @@ -806,7 +806,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i return 0, err } if len(controlData) > 0 { - if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyOutBytes(hostarch.Addr(msg.Control), controlData); err != nil { return 0, err } } @@ -821,7 +821,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i // recvFrom is the implementation of the recvfrom syscall. It is called by // recvfrom and recv syscall handlers. -func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) { +func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) { if int(bufLen) < 0 { return 0, syserror.EINVAL } @@ -997,7 +997,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return uintptr(count), nil, nil } -func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) { +func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr hostarch.Addr, flags int32) (uintptr, error) { // Capture the message header. 
var msg MessageHeader64 if _, err := msg.CopyIn(t, msgPtr); err != nil { @@ -1011,7 +1011,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme return 0, syserror.ENOBUFS } controlData = make([]byte, msg.ControlLen) - if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyInBytes(hostarch.Addr(msg.Control), controlData); err != nil { return 0, err } } @@ -1020,7 +1020,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme var to []byte if msg.NameLen != 0 { var err error - to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen) + to, err = CaptureAddress(t, hostarch.Addr(msg.Name), msg.NameLen) if err != nil { return 0, err } @@ -1030,7 +1030,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme if msg.IovLen > linux.UIO_MAXIOV { return 0, syserror.EMSGSIZE } - src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ + src, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ AddressSpaceActive: true, }) if err != nil { @@ -1064,7 +1064,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme // sendTo is the implementation of the sendto syscall. It is called by sendto // and send syscall handlers. -func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) { +func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) { bl := int(bufLen) if bl < 0 { return 0, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index cda29a8b5..2338ba44b 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -16,11 +16,11 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange @@ -106,7 +106,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } // stat implements stat from the given *fs.Dirent. -func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) error { +func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr hostarch.Addr) error { if dirPath && !fs.IsDir(d.Inode.StableAttr) { return syserror.ENOTDIR } @@ -120,7 +120,7 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err } // fstat implements fstat for the given *fs.File. -func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error { +func fstat(t *kernel.Task, f *fs.File, statAddr hostarch.Addr) error { uattr, err := f.UnstableAttr(t) if err != nil { return err @@ -180,7 +180,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall }) } -func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr usermem.Addr) error { +func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr hostarch.Addr) error { // "[T]he kernel may return fields that weren't requested and may fail to // return fields that were requested, depending on what the backing // filesystem supports. 
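Note: recvSingleMsg and sendSingleMsg above both turn msghdr fields, which are plain uint64s after unmarshalling MessageHeader64, into hostarch.Addr values before building an I/O sequence. A minimal sketch of the send-side pattern, assuming the field meanings shown in the diff (the helper is illustrative):

package sketch

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/kernel"
	"gvisor.dev/gvisor/pkg/syserror"
	"gvisor.dev/gvisor/pkg/usermem"
)

// iovecsForSend caps the iovec count at UIO_MAXIOV and converts the user
// pointer (a uint64 in the marshalled header) into a hostarch.Addr, as
// sendSingleMsg does before handing the sequence to the socket.
func iovecsForSend(t *kernel.Task, iovAddr uint64, iovLen uint64) (usermem.IOSequence, error) {
	if iovLen > linux.UIO_MAXIOV {
		return usermem.IOSequence{}, syserror.EMSGSIZE
	}
	return t.IovecsIOSequence(hostarch.Addr(iovAddr), int(iovLen), usermem.IOOpts{
		AddressSpaceActive: true,
	})
}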
@@ -257,7 +257,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // statfsImpl implements the linux syscall statfs and fstatfs based on a Dirent, // copying the statfs structure out to addr on success, otherwise an error is // returned. -func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error { +func statfsImpl(t *kernel.Task, d *fs.Dirent, addr hostarch.Addr) error { info, err := d.Inode.StatFS(t) if err != nil { return err diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index b5f920949..3185ea527 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -19,6 +19,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -46,7 +47,7 @@ var ( ExecMaxTotalSize = 2 * 1024 * 1024 // ExecMaxElemSize is the maximum length of a single argv or envv entry. - ExecMaxElemSize = 32 * usermem.PageSize + ExecMaxElemSize = 32 * hostarch.PageSize ) // Getppid implements linux syscall getppid(2). @@ -88,7 +89,7 @@ func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags) } -func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { +func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr hostarch.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX) if err != nil { return 0, nil, err @@ -199,7 +200,7 @@ func ExitGroup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } // clone is used by Clone, Fork, and VFork. -func clone(t *kernel.Task, flags int, stack usermem.Addr, parentTID usermem.Addr, childTID usermem.Addr, tls usermem.Addr) (uintptr, *kernel.SyscallControl, error) { +func clone(t *kernel.Task, flags int, stack hostarch.Addr, parentTID hostarch.Addr, childTID hostarch.Addr, tls hostarch.Addr) (uintptr, *kernel.SyscallControl, error) { opts := kernel.CloneOptions{ SharingOptions: kernel.SharingOptions{ NewAddressSpace: flags&linux.CLONE_VM == 0, @@ -274,7 +275,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error { } // wait4 waits for the given child process to exit. -func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusageAddr usermem.Addr) (uintptr, error) { +func wait4(t *kernel.Task, pid int, statusAddr hostarch.Addr, options int, rusageAddr hostarch.Addr) (uintptr, error) { if options&^(linux.WNOHANG|linux.WUNTRACED|linux.WCONTINUED|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 { return 0, syserror.EINVAL } diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index c5054d2f1..83b777bbd 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -19,12 +19,12 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // The most significant 29 bits hold either a pid or a file descriptor. 
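Note: the ExecMaxElemSize change above is purely mechanical, but the value it encodes is worth spelling out. Assuming the usual 4 KiB page size, this is the same bound Linux applies per argv/envv string (MAX_ARG_STRLEN):

package sketch

import "gvisor.dev/gvisor/pkg/hostarch"

// ExecMaxElemSize bounds a single argv or envv entry, mirroring Linux's
// MAX_ARG_STRLEN = 32 * PAGE_SIZE. With a 4 KiB page that is
// 32 * 4096 = 131072 bytes (128 KiB).
const ExecMaxElemSize = 32 * hostarch.PageSize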
@@ -165,7 +165,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC addr := args[0].Pointer() r := t.Kernel().RealtimeClock().Now().TimeT() - if addr == usermem.Addr(0) { + if addr == hostarch.Addr(0) { return uintptr(r), nil, nil } @@ -182,7 +182,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC type clockNanosleepRestartBlock struct { c ktime.Clock duration time.Duration - rem usermem.Addr + rem hostarch.Addr } // Restart implements kernel.SyscallRestartBlock.Restart. @@ -221,7 +221,7 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error // // If blocking is interrupted, the syscall is restarted with the remaining // duration timeout. -func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem usermem.Addr) error { +func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem hostarch.Addr) error { timer, start, tchan := ktime.After(c, dur) err := t.BlockWithTimer(nil, tchan) @@ -324,14 +324,14 @@ func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. tv := args[0].Pointer() tz := args[1].Pointer() - if tv != usermem.Addr(0) { + if tv != hostarch.Addr(0) { nowTv := t.Kernel().RealtimeClock().Now().Timeval() if err := copyTimevalOut(t, tv, &nowTv); err != nil { return 0, nil, err } } - if tz != usermem.Addr(0) { + if tz != hostarch.Addr(0) { // Ask the time package for the timezone. _, offset := time.Now().Zone() // This int32 array mimics linux's struct timezone. diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 97474fd3c..28ad6a60e 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -18,11 +18,11 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange @@ -87,7 +87,7 @@ func getXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink } // getXattr implements getxattr(2) from the given *fs.Dirent. -func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64) (int, error) { +func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, size uint64) (int, error) { name, err := copyInXattrName(t, nameAddr) if err != nil { return 0, err @@ -180,7 +180,7 @@ func setXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink } // setXattr implements setxattr(2) from the given *fs.Dirent. 
-func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64, flags uint32) error { +func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, size uint64, flags uint32) error { if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { return syserror.EINVAL } @@ -214,7 +214,7 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, si return nil } -func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { +func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) { name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) if err != nil { if err == syserror.ENAMETOOLONG { @@ -306,7 +306,7 @@ func listXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlin return uintptr(n), nil, nil } -func listXattr(t *kernel.Task, d *fs.Dirent, addr usermem.Addr, size uint64) (int, error) { +func listXattr(t *kernel.Task, d *fs.Dirent, addr hostarch.Addr, size uint64) (int, error) { if !xattrFileTypeOk(d.Inode) { return 0, nil } @@ -408,7 +408,7 @@ func removeXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSyml } // removeXattr implements removexattr(2) from the given *fs.Dirent. -func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error { +func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr hostarch.Addr) error { name, err := copyInXattrName(t, nameAddr) if err != nil { return err diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go index ddc3ee26e..3edc922eb 100644 --- a/pkg/sentry/syscalls/linux/timespec.go +++ b/pkg/sentry/syscalls/linux/timespec.go @@ -18,13 +18,13 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // copyTimespecIn copies a Timespec from the untrusted app range to the kernel. -func copyTimespecIn(t *kernel.Task, addr usermem.Addr) (linux.Timespec, error) { +func copyTimespecIn(t *kernel.Task, addr hostarch.Addr) (linux.Timespec, error) { switch t.Arch().Width() { case 8: ts := linux.Timespec{} @@ -33,8 +33,8 @@ func copyTimespecIn(t *kernel.Task, addr usermem.Addr) (linux.Timespec, error) { if err != nil { return ts, err } - ts.Sec = int64(usermem.ByteOrder.Uint64(in[0:])) - ts.Nsec = int64(usermem.ByteOrder.Uint64(in[8:])) + ts.Sec = int64(hostarch.ByteOrder.Uint64(in[0:])) + ts.Nsec = int64(hostarch.ByteOrder.Uint64(in[8:])) return ts, nil default: return linux.Timespec{}, syserror.ENOSYS @@ -42,12 +42,12 @@ func copyTimespecIn(t *kernel.Task, addr usermem.Addr) (linux.Timespec, error) { } // copyTimespecOut copies a Timespec to the untrusted app range. -func copyTimespecOut(t *kernel.Task, addr usermem.Addr, ts *linux.Timespec) error { +func copyTimespecOut(t *kernel.Task, addr hostarch.Addr, ts *linux.Timespec) error { switch t.Arch().Width() { case 8: out := t.CopyScratchBuffer(16) - usermem.ByteOrder.PutUint64(out[0:], uint64(ts.Sec)) - usermem.ByteOrder.PutUint64(out[8:], uint64(ts.Nsec)) + hostarch.ByteOrder.PutUint64(out[0:], uint64(ts.Sec)) + hostarch.ByteOrder.PutUint64(out[8:], uint64(ts.Nsec)) _, err := t.CopyOutBytes(addr, out) return err default: @@ -56,7 +56,7 @@ func copyTimespecOut(t *kernel.Task, addr usermem.Addr, ts *linux.Timespec) erro } // copyTimevalIn copies a Timeval from the untrusted app range to the kernel. 
-func copyTimevalIn(t *kernel.Task, addr usermem.Addr) (linux.Timeval, error) { +func copyTimevalIn(t *kernel.Task, addr hostarch.Addr) (linux.Timeval, error) { switch t.Arch().Width() { case 8: tv := linux.Timeval{} @@ -65,8 +65,8 @@ func copyTimevalIn(t *kernel.Task, addr usermem.Addr) (linux.Timeval, error) { if err != nil { return tv, err } - tv.Sec = int64(usermem.ByteOrder.Uint64(in[0:])) - tv.Usec = int64(usermem.ByteOrder.Uint64(in[8:])) + tv.Sec = int64(hostarch.ByteOrder.Uint64(in[0:])) + tv.Usec = int64(hostarch.ByteOrder.Uint64(in[8:])) return tv, nil default: return linux.Timeval{}, syserror.ENOSYS @@ -74,12 +74,12 @@ func copyTimevalIn(t *kernel.Task, addr usermem.Addr) (linux.Timeval, error) { } // copyTimevalOut copies a Timeval to the untrusted app range. -func copyTimevalOut(t *kernel.Task, addr usermem.Addr, tv *linux.Timeval) error { +func copyTimevalOut(t *kernel.Task, addr hostarch.Addr, tv *linux.Timeval) error { switch t.Arch().Width() { case 8: out := t.CopyScratchBuffer(16) - usermem.ByteOrder.PutUint64(out[0:], uint64(tv.Sec)) - usermem.ByteOrder.PutUint64(out[8:], uint64(tv.Usec)) + hostarch.ByteOrder.PutUint64(out[0:], uint64(tv.Sec)) + hostarch.ByteOrder.PutUint64(out[8:], uint64(tv.Usec)) _, err := t.CopyOutBytes(addr, out) return err default: @@ -94,7 +94,7 @@ func copyTimevalOut(t *kernel.Task, addr usermem.Addr, tv *linux.Timeval) error // returned value is the maximum that Duration will allow. // // If timespecAddr is NULL, the returned value is negative. -func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) { +func copyTimespecInToDuration(t *kernel.Task, timespecAddr hostarch.Addr) (time.Duration, error) { // Use a negative Duration to indicate "no timeout". timeout := time.Duration(-1) if timespecAddr != 0 { diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 2e59bd5b1..5ce0bc714 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -43,6 +43,7 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/gohacks", + "//pkg/hostarch", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index de6789a65..fd1863ef3 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -26,6 +26,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // IoSubmit implements linux syscall io_submit(2). @@ -40,7 +42,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc for i := int32(0); i < nrEvents; i++ { // Copy in the callback address. - var cbAddr usermem.Addr + var cbAddr hostarch.Addr switch t.Arch().Width() { case 8: var cbAddrP primitive.Uint64 @@ -52,7 +54,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Nothing done. return 0, nil, err } - cbAddr = usermem.Addr(cbAddrP) + cbAddr = hostarch.Addr(cbAddrP) default: return 0, nil, syserror.ENOSYS } @@ -79,14 +81,14 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } // Advance to the next one. - addr += usermem.Addr(t.Arch().Width()) + addr += hostarch.Addr(t.Arch().Width()) } return uintptr(nrEvents), nil, nil } // submitCallback processes a single callback. 
-func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error { +func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr hostarch.Addr) error { if cb.Reserved2 != 0 { return syserror.EINVAL } @@ -148,7 +150,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user return nil } -func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback { +func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr hostarch.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback { return func(ctx context.Context) { // Release references after completing the callback. defer fd.DecRef(ctx) @@ -206,12 +208,12 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) // I/O. switch cb.OpCode { case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE: - return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{ + return t.SingleIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{ AddressSpaceActive: false, }) case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV: - return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{ + return t.IovecsIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{ AddressSpaceActive: false, }) diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go index 7a409620d..3315398a4 100644 --- a/pkg/sentry/syscalls/linux/vfs2/execve.go +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -24,7 +24,8 @@ import ( slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Execve implements linux syscall execve(2). @@ -45,7 +46,7 @@ func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags) } -func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { +func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr hostarch.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { return 0, nil, syserror.EINVAL } diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index 01e0f9010..36aa1d3ae 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -20,7 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Link implements Linux syscall link(2). 
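Note: memoryFor above dispatches on the AIO opcode: single-buffer reads and writes become a one-element I/O sequence, the vectored variants go through the iovec path, and in both cases cb.Buf (a uint64 in the callback struct) is converted to hostarch.Addr first. A sketch of that dispatch; the other opcodes (fsync, noop, etc.) and their handling are omitted here, and the default branch is illustrative:

package sketch

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/kernel"
	"gvisor.dev/gvisor/pkg/syserror"
	"gvisor.dev/gvisor/pkg/usermem"
)

// sequenceFor mirrors memoryFor's opcode dispatch. AddressSpaceActive is
// false because the I/O runs asynchronously, outside the submitting task's
// active address space.
func sequenceFor(t *kernel.Task, cb *linux.IOCallback, bytes int) (usermem.IOSequence, error) {
	switch cb.OpCode {
	case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
		return t.SingleIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{
			AddressSpaceActive: false,
		})
	case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
		return t.IovecsIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{
			AddressSpaceActive: false,
		})
	default:
		return usermem.IOSequence{}, syserror.EINVAL // illustrative; see memoryFor for the full set
	}
}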
@@ -40,7 +41,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags) } -func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error { +func linkat(t *kernel.Task, olddirfd int32, oldpathAddr hostarch.Addr, newdirfd int32, newpathAddr hostarch.Addr, flags int32) error { if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 { return syserror.EINVAL } @@ -86,7 +87,7 @@ func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, mkdirat(t, dirfd, addr, mode) } -func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error { +func mkdirat(t *kernel.Task, dirfd int32, addr hostarch.Addr, mode uint) error { path, err := copyInPath(t, addr) if err != nil { return err @@ -118,7 +119,7 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev) } -func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error { +func mknodat(t *kernel.Task, dirfd int32, addr hostarch.Addr, mode linux.FileMode, dev uint32) error { path, err := copyInPath(t, addr) if err != nil { return err @@ -165,7 +166,7 @@ func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode) } -func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) { +func openat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) { path, err := copyInPath(t, pathAddr) if err != nil { return 0, nil, err @@ -217,7 +218,7 @@ func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags) } -func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error { +func renameat(t *kernel.Task, olddirfd int32, oldpathAddr hostarch.Addr, newdirfd int32, newpathAddr hostarch.Addr, flags uint32) error { oldpath, err := copyInPath(t, oldpathAddr) if err != nil { return err @@ -250,7 +251,7 @@ func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr) } -func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { +func rmdirat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr) error { path, err := copyInPath(t, pathAddr) if err != nil { return err @@ -269,7 +270,7 @@ func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr) } -func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { +func unlinkat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr) error { path, err := copyInPath(t, pathAddr) if err != nil { return err @@ -313,7 +314,7 @@ func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr) } -func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error { +func symlinkat(t *kernel.Task, targetAddr hostarch.Addr, newdirfd int32, linkpathAddr hostarch.Addr) error { target, err := 
t.CopyInString(targetAddr, linux.PATH_MAX) if err != nil { return err diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go index 5517595b5..b41a3056a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/getdents.go +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go @@ -22,7 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Getdents implements Linux syscall getdents(2). @@ -58,7 +59,7 @@ func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (ui type getdentsCallback struct { t *kernel.Task - addr usermem.Addr + addr hostarch.Addr remaining int isGetdents64 bool } @@ -69,7 +70,7 @@ var getdentsCallbackPool = sync.Pool{ }, } -func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback { +func getGetdentsCallback(t *kernel.Task, addr hostarch.Addr, size int, isGetdents64 bool) *getdentsCallback { cb := getdentsCallbackPool.Get().(*getdentsCallback) *cb = getdentsCallback{ t: t, @@ -102,9 +103,9 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { return syserror.EINVAL } buf = cb.t.CopyScratchBuffer(size) - usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino) - usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) - usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) + hostarch.ByteOrder.PutUint64(buf[0:8], dirent.Ino) + hostarch.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) + hostarch.ByteOrder.PutUint16(buf[16:18], uint16(size)) buf[18] = dirent.Type copy(buf[19:], dirent.Name) // Zero out all remaining bytes in buf, including the NUL terminator @@ -136,9 +137,9 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { return syserror.EINVAL } buf = cb.t.CopyScratchBuffer(size) - usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino) - usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) - usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) + hostarch.ByteOrder.PutUint64(buf[0:8], dirent.Ino) + hostarch.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) + hostarch.ByteOrder.PutUint16(buf[16:18], uint16(size)) copy(buf[18:], dirent.Name) // Zero out all remaining bytes in buf, including the NUL terminator // after dirent.Name and the zero padding byte between the name and @@ -155,7 +156,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { // cb.remaining. return err } - cb.addr += usermem.Addr(n) + cb.addr += hostarch.Addr(n) cb.remaining -= n return nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go index 9d9dbf775..c961545f6 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mmap.go +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go @@ -21,7 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Mmap implements Linux syscall mmap(2). 
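Note: the getdentsCallback.Handle hunks above replace usermem.ByteOrder with hostarch.ByteOrder but keep the same hand-rolled struct linux_dirent64 layout: 8-byte inode, 8-byte offset, 2-byte record length, 1-byte type, then the NUL-terminated name. A sketch of that serialization for the getdents64 path, with buffer sizing and padding simplified:

package sketch

import "gvisor.dev/gvisor/pkg/hostarch"

// encodeDirent64 lays out one linux_dirent64 record the way the getdents64
// path above does. size must already include the name, its NUL terminator,
// and any padding up to the record boundary, and buf must be at least that
// large.
func encodeDirent64(buf []byte, ino uint64, nextOff int64, size uint16, typ uint8, name string) {
	hostarch.ByteOrder.PutUint64(buf[0:8], ino)               // d_ino
	hostarch.ByteOrder.PutUint64(buf[8:16], uint64(nextOff))  // d_off
	hostarch.ByteOrder.PutUint16(buf[16:18], size)            // d_reclen
	buf[18] = typ                                             // d_type
	copy(buf[19:], name)                                      // d_name; trailing bytes stay zeroed
}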
@@ -48,12 +49,12 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Unmap: fixed, Map32Bit: map32bit, Private: private, - Perms: usermem.AccessType{ + Perms: hostarch.AccessType{ Read: linux.PROT_READ&prot != 0, Write: linux.PROT_WRITE&prot != 0, Execute: linux.PROT_EXEC&prot != 0, }, - MaxPerms: usermem.AnyAccess, + MaxPerms: hostarch.AnyAccess, GrowsDown: linux.MAP_GROWSDOWN&flags != 0, Precommit: linux.MAP_POPULATE&flags != 0, } diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index 769c9b92f..dd93430e2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -20,7 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Mount implements Linux syscall mount(2). @@ -33,11 +34,11 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // For null-terminated strings related to mount(2), Linux copies in at most // a page worth of data. See fs/namespace.c:copy_mount_string(). - fsType, err := t.CopyInString(typeAddr, usermem.PageSize) + fsType, err := t.CopyInString(typeAddr, hostarch.PageSize) if err != nil { return 0, nil, err } - source, err := t.CopyInString(sourceAddr, usermem.PageSize) + source, err := t.CopyInString(sourceAddr, hostarch.PageSize) if err != nil { return 0, nil, err } @@ -53,7 +54,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // character placement, and the address is passed to each file system. // Most file systems always treat this data as a string, though, and so // do all of the ones we implement. - data, err = t.CopyInString(dataAddr, usermem.PageSize) + data, err = t.CopyInString(dataAddr, hostarch.PageSize) if err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go index 90a511d9a..2aaf1ed74 100644 --- a/pkg/sentry/syscalls/linux/vfs2/path.go +++ b/pkg/sentry/syscalls/linux/vfs2/path.go @@ -20,10 +20,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) -func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) { +func copyInPath(t *kernel.Task, addr hostarch.Addr) (fspath.Path, error) { pathname, err := t.CopyInString(addr, linux.PATH_MAX) if err != nil { return fspath.Path{}, err diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index 6986e39fe..c6fc1954c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -22,7 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Pipe implements Linux syscall pipe(2). 
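Note: both Mmap hunks (the fs version in sys_mmap.go and the vfs2 one above) derive the mapping permissions from the PROT_* bits in the same way; only the package holding AccessType changes. A minimal sketch of that conversion, with the helper name being illustrative:

package sketch

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/hostarch"
)

// protToAccessType converts mmap/mprotect PROT_* bits into the
// hostarch.AccessType used for MMapOpts.Perms; MaxPerms stays at
// hostarch.AnyAccess so the mapping can later be widened with mprotect.
func protToAccessType(prot int32) hostarch.AccessType {
	return hostarch.AccessType{
		Read:    linux.PROT_READ&prot != 0,
		Write:   linux.PROT_WRITE&prot != 0,
		Execute: linux.PROT_EXEC&prot != 0,
	}
}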
@@ -38,7 +39,7 @@ func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, pipe2(t, addr, flags) } -func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { +func pipe2(t *kernel.Task, addr hostarch.Addr, flags int32) error { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return syserror.EINVAL } diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index c22e4ce54..a69c80edd 100644 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -25,8 +25,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + + "gvisor.dev/gvisor/pkg/hostarch" ) // fileCap is the maximum allowable files for poll & select. This has no @@ -158,7 +159,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. } // copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max. -func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) { +func copyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) { if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) { return nil, syserror.EINVAL } @@ -173,7 +174,7 @@ func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD return pfd, nil } -func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) { +func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) { pfd, err := copyInPollFDs(t, addr, nfds) if err != nil { return timeout, 0, err @@ -201,7 +202,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) } // CopyInFDSet copies an fd set from select(2)/pselect(2). -func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { +func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { set := make([]byte, nBytes) if addr != 0 { @@ -218,7 +219,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy return set, nil } -func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) { +func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) { if nfds < 0 || nfds > fileCap { return 0, syserror.EINVAL } @@ -368,7 +369,7 @@ func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) // copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr. // // startNs must be from CLOCK_MONOTONIC. -func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error { +func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr hostarch.Addr) error { if timeout <= 0 { return nil } @@ -381,7 +382,7 @@ func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.D // copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr. // // startNs must be from CLOCK_MONOTONIC. 
-func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error { +func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr hostarch.Addr) error { if timeout <= 0 { return nil } @@ -396,7 +397,7 @@ func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Du // // +stateify savable type pollRestartBlock struct { - pfdAddr usermem.Addr + pfdAddr hostarch.Addr nfds uint timeout time.Duration } @@ -406,7 +407,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) { return poll(t, p.pfdAddr, p.nfds, p.timeout) } -func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) { +func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) { remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout) // On an interrupt poll(2) is restarted with the remaining timeout. if err == syserror.EINTR { @@ -530,7 +531,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if _, err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil { return 0, nil, err } - if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil { + if err := setTempSignalSet(t, hostarch.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil { return 0, nil, err } } @@ -551,7 +552,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // returned value is the maximum that Duration will allow. // // If timespecAddr is NULL, the returned value is negative. -func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) { +func copyTimespecInToDuration(t *kernel.Task, timespecAddr hostarch.Addr) (time.Duration, error) { // Use a negative Duration to indicate "no timeout". 
timeout := time.Duration(-1) if timespecAddr != 0 { @@ -567,7 +568,7 @@ func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.D return timeout, nil } -func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error { +func setTempSignalSet(t *kernel.Task, maskAddr hostarch.Addr, maskSize uint) error { if maskAddr == 0 { return nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 903169dc2..c6330c21a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -23,7 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX @@ -43,7 +44,7 @@ func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, fchmodat(t, dirfd, pathAddr, mode) } -func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error { +func fchmodat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, mode uint) error { path, err := copyInPath(t, pathAddr) if err != nil { return err @@ -102,7 +103,7 @@ func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags) } -func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error { +func fchownat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, owner, group, flags int32) error { if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { return syserror.EINVAL } @@ -327,7 +328,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, followFinalSymlink, &opts) } -func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error { +func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr hostarch.Addr, opts *vfs.SetStatOptions) error { if timesAddr == 0 { opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME opts.Stat.Atime.Nsec = linux.UTIME_NOW @@ -391,7 +392,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts) } -func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error { +func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr hostarch.Addr, opts *vfs.SetStatOptions) error { if timesAddr == 0 { opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME opts.Stat.Atime.Nsec = linux.UTIME_NOW diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go index b89f34cdb..6163da103 100644 --- a/pkg/sentry/syscalls/linux/vfs2/signal.go +++ b/pkg/sentry/syscalls/linux/vfs2/signal.go @@ -21,11 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // sharedSignalfd is shared between the two calls. 
-func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) { +func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) { // Copy in the signal mask. mask, err := slinux.CopyInSigSet(t, sigset, sigsetsize) if err != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 346fd1cea..936614eab 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -31,13 +31,9 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" -) - -// minListenBacklog is the minimum reasonable backlog for listening sockets. -const minListenBacklog = 8 -// maxListenBacklog is the maximum allowed backlog for listening sockets. -const maxListenBacklog = 1024 + "gvisor.dev/gvisor/pkg/hostarch" +) // maxAddrLen is the maximum socket address length we're willing to accept. const maxAddrLen = 200 @@ -50,6 +46,9 @@ const maxOptLen = 1024 * 8 // buffers upto INT_MAX. const maxControlLen = 10 * 1024 * 1024 +// maxListenBacklog is the maximum limit of listen backlog supported. +const maxListenBacklog = 1024 + // nameLenOffset is the offset from the start of the MessageHeader64 struct to // the NameLen field. const nameLenOffset = 8 @@ -116,7 +115,7 @@ type multipleMessageHeader64 struct { // CaptureAddress allocates memory for and copies a socket address structure // from the untrusted address space range. -func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { +func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) { if addrlen > maxAddrLen { return nil, syserror.EINVAL } @@ -132,7 +131,7 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, // writeAddress writes a sockaddr structure and its length to an output buffer // in the unstrusted address space range. If the address is bigger than the // buffer, it is truncated. -func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { +func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr hostarch.Addr, addrLenPtr hostarch.Addr) error { // Get the buffer length. var bufLen uint32 if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil { @@ -279,7 +278,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // accept is the implementation of the accept syscall. It is called by accept // and accept4 syscall handlers. -func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) { +func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) { // Check that no unsupported flags are passed in. if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { return 0, syserror.EINVAL @@ -369,7 +368,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Listen implements the linux syscall listen(2). func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() - backlog := args[1].Int() + backlog := args[1].Uint() // Get socket from the file descriptor. 
file := t.GetFileVFS2(fd) @@ -384,11 +383,13 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ENOTSOCK } - // Per Linux, the backlog is silently capped to reasonable values. - if backlog <= 0 { - backlog = minListenBacklog - } if backlog > maxListenBacklog { + // Linux treats incoming backlog as uint with a limit defined by + // sysctl_somaxconn. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666 + // + // We use the backlog to allocate a channel of that size, hence enforce + // a hard limit for the backlog. backlog = maxListenBacklog } @@ -475,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr hostarch.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -738,7 +739,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return uintptr(count), nil, nil } -func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { +func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr hostarch.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { // Capture the message header and io vectors. var msg MessageHeader64 if _, err := msg.CopyIn(t, msgPtr); err != nil { @@ -748,7 +749,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla if msg.IovLen > linux.UIO_MAXIOV { return 0, syserror.EMSGSIZE } - dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ + dst, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ AddressSpaceActive: true, }) if err != nil { @@ -799,7 +800,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla // Copy the address to the caller. if msg.NameLen != 0 { - if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil { + if err := writeAddress(t, sender, senderLen, hostarch.Addr(msg.Name), hostarch.Addr(msgPtr+nameLenOffset)); err != nil { return 0, err } } @@ -809,7 +810,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla return 0, err } if len(controlData) > 0 { - if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyOutBytes(hostarch.Addr(msg.Control), controlData); err != nil { return 0, err } } @@ -824,7 +825,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla // recvFrom is the implementation of the recvfrom syscall. It is called by // recvfrom and recv syscall handlers. 
-func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) { +func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) { if int(bufLen) < 0 { return 0, syserror.EINVAL } @@ -1000,7 +1001,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return uintptr(count), nil, nil } -func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) { +func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr hostarch.Addr, flags int32) (uintptr, error) { // Capture the message header. var msg MessageHeader64 if _, err := msg.CopyIn(t, msgPtr); err != nil { @@ -1014,7 +1015,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio return 0, syserror.ENOBUFS } controlData = make([]byte, msg.ControlLen) - if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyInBytes(hostarch.Addr(msg.Control), controlData); err != nil { return 0, err } } @@ -1023,7 +1024,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio var to []byte if msg.NameLen != 0 { var err error - to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen) + to, err = CaptureAddress(t, hostarch.Addr(msg.Name), msg.NameLen) if err != nil { return 0, err } @@ -1033,7 +1034,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio if msg.IovLen > linux.UIO_MAXIOV { return 0, syserror.EMSGSIZE } - src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ + src, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ AddressSpaceActive: true, }) if err != nil { @@ -1067,7 +1068,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio // sendTo is the implementation of the sendto syscall. It is called by sendto // and send syscall handlers. -func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) { +func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) { bl := int(bufLen) if bl < 0 { return 0, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go index 0f5d5189c..69e77fa99 100644 --- a/pkg/sentry/syscalls/linux/vfs2/stat.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -24,7 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // Stat implements Linux syscall stat(2). 
@@ -50,7 +51,7 @@ func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags) } -func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error { +func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr hostarch.Addr, flags int32) error { if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { return syserror.EINVAL } @@ -264,7 +265,7 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, accessAt(t, dirfd, addr, mode) } -func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error { +func accessAt(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, mode uint) error { const rOK = 4 const wOK = 2 const xOK = 1 @@ -312,7 +313,7 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return readlinkat(t, dirfd, pathAddr, bufAddr, size) } -func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) { +func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr hostarch.Addr, size uint) (uintptr, *kernel.SyscallControl, error) { if int(size) <= 0 { return 0, nil, syserror.EINVAL } diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index e05723ef9..c261050c6 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -23,7 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" + + "gvisor.dev/gvisor/pkg/hostarch" ) // ListXattr implements Linux syscall listxattr(2). @@ -291,7 +292,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. 
return 0, nil, file.RemoveXattr(t, name) } -func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { +func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) { name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) if err != nil { if err == syserror.ENAMETOOLONG { @@ -305,7 +306,7 @@ func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { return name, nil } -func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) { +func copyOutXattrNameList(t *kernel.Task, listAddr hostarch.Addr, size uint, names []string) (int, error) { if size > linux.XATTR_LIST_MAX { size = linux.XATTR_LIST_MAX } @@ -327,7 +328,7 @@ func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, name return t.CopyOutBytes(listAddr, buf.Bytes()) } -func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) { +func copyInXattrValue(t *kernel.Task, valueAddr hostarch.Addr, size uint) (string, error) { if size > linux.XATTR_SIZE_MAX { return "", syserror.E2BIG } @@ -338,7 +339,7 @@ func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string return gohacks.StringFromImmutableBytes(buf), nil } -func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) { +func copyOutXattrValue(t *kernel.Task, valueAddr hostarch.Addr, size uint, value string) (int, error) { if size > linux.XATTR_SIZE_MAX { size = linux.XATTR_SIZE_MAX } diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 87d8687ce..1f617ca8f 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -32,6 +32,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/gohacks", "//pkg/log", "//pkg/metric", "//pkg/sync", diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index df4990854..ac60fe8bf 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -99,6 +99,7 @@ go_library( "//pkg/fdnotifier", "//pkg/fspath", "//pkg/gohacks", + "//pkg/hostarch", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 3caf417ca..f48817132 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -20,10 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // NewAnonVirtualDentry returns a VirtualDentry with the given synthetic name, @@ -43,7 +43,7 @@ func (vfs *VirtualFilesystem) NewAnonVirtualDentry(name string) VirtualDentry { } const ( - anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super() + anonfsBlockSize = hostarch.PageSize // via fs/libfs.c:pseudo_fs_fill_super() // Mode, UID, and GID for a generic anonfs file. anonFileMode = 0600 // no type is correct diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 1556b41a3..b87d9690a 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -252,6 +252,9 @@ type WritableDynamicBytesSource interface { // are backed by a bytes.Buffer that is regenerated when necessary, consistent // with Linux's fs/seq_file.c:single_open(). 
// +// If data additionally implements WritableDynamicBytesSource, writes are +// dispatched to the implementer. The source data is not automatically modified. +// // DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first // use. // diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go index 2620cf975..15b234d61 100644 --- a/pkg/sentry/vfs/filesystem_impl_util.go +++ b/pkg/sentry/vfs/filesystem_impl_util.go @@ -18,7 +18,7 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) // GenericParseMountOptions parses a comma-separated list of options of the @@ -50,7 +50,7 @@ func GenericParseMountOptions(str string) map[string]string { func GenericStatFS(fsMagic uint64) linux.Statfs { return linux.Statfs{ Type: fsMagic, - BlockSize: usermem.PageSize, + BlockSize: hostarch.PageSize, NameLength: linux.NAME_MAX, } } diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 32fa01578..49d29e20b 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sync" @@ -256,7 +257,7 @@ func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallAr n += uint32(e.sizeOf()) } var buf [4]byte - usermem.ByteOrder.PutUint32(buf[:], n) + hostarch.ByteOrder.PutUint32(buf[:], n) _, err := uio.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{}) return 0, err @@ -683,10 +684,10 @@ func (e *Event) sizeOf() int { // construct the output. We use a buffer allocated ahead of time for // performance. buf must be at least inotifyEventBaseSize bytes. func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) { - usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd)) - usermem.ByteOrder.PutUint32(buf[4:], e.mask) - usermem.ByteOrder.PutUint32(buf[8:], e.cookie) - usermem.ByteOrder.PutUint32(buf[12:], e.len) + hostarch.ByteOrder.PutUint32(buf[0:], uint32(e.wd)) + hostarch.ByteOrder.PutUint32(buf[4:], e.mask) + hostarch.ByteOrder.PutUint32(buf[8:], e.cookie) + hostarch.ByteOrder.PutUint32(buf[12:], e.len) writeLen := 0 diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 922f9e697..7cdab6945 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -970,17 +970,22 @@ func superBlockOpts(mountPath string, mnt *Mount) string { opts += "," + mopts } - // NOTE(b/147673608): If the mount is a cgroup, we also need to include - // the cgroup name in the options. For now we just read that from the - // path. + // NOTE(b/147673608): If the mount is a ramdisk-based fake cgroupfs, we also + // need to include the cgroup name in the options. For now we just read that + // from the path. Note that this is only possible when "cgroup" isn't + // registered as a valid filesystem type. // - // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we - // should get this value from the cgroup itself, and not rely on the - // path. + // TODO(gvisor.dev/issue/190): Once we removed fake cgroupfs support, we + // should remove this. + if cgroupfs := mnt.vfs.getFilesystemType("cgroup"); cgroupfs != nil && cgroupfs.opts.AllowUserMount { + // Real cgroupfs available. 
+ return opts + } if mnt.fs.FilesystemType().Name() == "cgroup" { splitPath := strings.Split(mountPath, "/") cgroupType := splitPath[len(splitPath)-1] opts += "," + cgroupType } + return opts } diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index b2c5229e7..8b3a11c64 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -43,6 +43,7 @@ go_template( ], deps = [ ":sync", + "//pkg/gohacks", ], ) diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go index 82b676abf..9578c9c52 100644 --- a/pkg/sync/generic_seqatomic_unsafe.go +++ b/pkg/sync/generic_seqatomic_unsafe.go @@ -10,6 +10,7 @@ package seqatomic import ( "unsafe" + "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/sync" ) @@ -39,7 +40,7 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) // runtime.RaceDisable() doesn't actually stop the race detector, so it // can't help us here. Instead, call runtime.memmove directly, which is // not instrumented by the race detector. - sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + gohacks.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { // This is ~40% faster for short reads than going through memmove. val = *ptr diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go index 158985709..39c766331 100644 --- a/pkg/sync/runtime_unsafe.go +++ b/pkg/sync/runtime_unsafe.go @@ -17,20 +17,6 @@ import ( "unsafe" ) -// Note that go:linkname silently doesn't work if the local name is exported, -// necessitating an indirection for exported functions. - -// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>. -// -//go:nosplit -func Memmove(to, from unsafe.Pointer, n uintptr) { - memmove(to, from, n) -} - -//go:linkname memmove runtime.memmove -//go:noescape -func memmove(to, from unsafe.Pointer, n uintptr) - // Gopark is runtime.gopark. Gopark calls unlockf(pointer to runtime.g, lock); // if unlockf returns true, Gopark blocks until Goready(pointer to runtime.g) // is called. 
unlockf and its callees must be nosplit and norace, since stack diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD index 5c38c783e..5f9164117 100644 --- a/pkg/sync/seqatomictest/BUILD +++ b/pkg/sync/seqatomictest/BUILD @@ -18,6 +18,7 @@ go_library( name = "seqatomic", srcs = ["seqatomic_int_unsafe.go"], deps = [ + "//pkg/gohacks", "//pkg/sync", ], ) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index f979d22f0..aa30cfc85 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,4 +1,5 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:deps.bzl", "deps_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -33,6 +34,36 @@ go_library( ], ) +deps_test( + name = "netstack_deps_test", + allowed = [ + "@com_github_google_btree//:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + "@org_golang_x_time//rate:go_default_library", + ], + allowed_prefixes = [ + "//", + "@org_golang_x_sys//internal/unsafeheader", + ], + targets = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/fdbased", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/packetsocket", + "//pkg/tcpip/link/qdisc/fifo", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + ], +) + go_test( name = "tcpip_test", size = "small", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index fef065b05..12c39dfa3 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -53,9 +53,8 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { t.Error("Not a valid IPv4 packet") } - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) + if !ipv4.IsChecksumValid() { + t.Errorf("Bad checksum, got = %d", ipv4.Checksum()) } for _, f := range checkers { @@ -400,18 +399,11 @@ func TCP(checkers ...TransportChecker) NetworkChecker { t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } - // Verify the checksum. tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) + payload := tcp.Payload() + payloadChecksum := header.Checksum(payload, 0) + if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) { + t.Errorf("Bad checksum, got = %d", tcp.Checksum()) } // Run the transport checkers. 
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go index 52c22230e..33ff22a7b 100644 --- a/pkg/tcpip/hash/jenkins/jenkins.go +++ b/pkg/tcpip/hash/jenkins/jenkins.go @@ -42,26 +42,26 @@ func (s *Sum32) Reset() { *s = 0 } // Sum32 returns the hash value func (s *Sum32) Sum32() uint32 { - hash := *s + sCopy := *s - hash += (hash << 3) - hash ^= hash >> 11 - hash += hash << 15 + sCopy += sCopy << 3 + sCopy ^= sCopy >> 11 + sCopy += sCopy << 15 - return uint32(hash) + return uint32(sCopy) } // Write adds more data to the running hash. // // It never returns an error. func (s *Sum32) Write(data []byte) (int, error) { - hash := *s + sCopy := *s for _, b := range data { - hash += Sum32(b) - hash += hash << 10 - hash ^= hash >> 6 + sCopy += Sum32(b) + sCopy += sCopy << 10 + sCopy ^= sCopy >> 6 } - *s = hash + *s = sCopy return len(data), nil } diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 0bdc12d53..01240f5d0 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -52,6 +52,7 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -69,6 +70,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 3bc8b2b21..bf9ccbf1a 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -18,6 +18,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) func TestIsValidUnicastEthernetAddress(t *testing.T) { @@ -142,7 +143,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { } func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) { - addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a") + addr := testutil.MustParse6("ff02:304:506:708:90a:b0c:d0e:f1a") if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want { t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want) } diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go index b6126d29a..575604928 100644 --- a/pkg/tcpip/header/igmp_test.go +++ b/pkg/tcpip/header/igmp_test.go @@ -18,8 +18,8 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) // TestIGMPHeader tests the functions within header.igmp @@ -46,7 +46,7 @@ func TestIGMPHeader(t *testing.T) { t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) } - if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { + if got, want := igmpHeader.GroupAddress(), testutil.MustParse4("1.2.3.4"); got != want { t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) } @@ -71,7 +71,7 @@ func TestIGMPHeader(t *testing.T) { t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) } - groupAddress := tcpip.Address("\x04\x03\x02\x01") + groupAddress := testutil.MustParse4("4.3.2.1") igmpHeader.SetGroupAddress(groupAddress) if got := igmpHeader.GroupAddress(); got != groupAddress { t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index f588311e0..2be21ec75 100644 --- a/pkg/tcpip/header/ipv4.go +++ 
b/pkg/tcpip/header/ipv4.go @@ -178,6 +178,26 @@ const ( IPv4FlagDontFragment ) +// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined +// by RFC 3927 section 1. +var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + +// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as +// defined by RFC 5771 section 4. +var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00")) + if err != nil { + panic(err) + } + return subnet +}() + // IPv4EmptySubnet is the empty IPv4 subnet. var IPv4EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any)) @@ -423,6 +443,44 @@ func (b IPv4) IsValid(pktSize int) bool { return true } +// IsV4LinkLocalUnicastAddress determines if the provided address is an IPv4 +// link-local unicast address. +func IsV4LinkLocalUnicastAddress(addr tcpip.Address) bool { + return ipv4LinkLocalUnicastSubnet.Contains(addr) +} + +// IsV4LinkLocalMulticastAddress determines if the provided address is an IPv4 +// link-local multicast address. +func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool { + return ipv4LinkLocalMulticastSubnet.Contains(addr) +} + +// IsChecksumValid returns true iff the IPv4 header's checksum is valid. +func (b IPv4) IsChecksumValid() bool { + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + return b.CalculateChecksum() == 0xffff +} + // IsV4MulticastAddress determines if the provided address is an IPv4 multicast // address (range 224.0.0.0 to 239.255.255.255). The four most significant bits // will be 1110 = 0xe0. 
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go index 6475cd694..c02fe898b 100644 --- a/pkg/tcpip/header/ipv4_test.go +++ b/pkg/tcpip/header/ipv4_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -177,3 +178,77 @@ func TestIPv4EncodeOptions(t *testing.T) { }) } } + +func TestIsV4LinkLocalUnicastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid (lowest)", + addr: "\xa9\xfe\x00\x00", + expected: true, + }, + { + name: "Valid (highest)", + addr: "\xa9\xfe\xff\xff", + expected: true, + }, + { + name: "Invalid (before subnet)", + addr: "\xa9\xfd\xff\xff", + expected: false, + }, + { + name: "Invalid (after subnet)", + addr: "\xa9\xff\x00\x00", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV4LinkLocalMulticastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid (lowest)", + addr: "\xe0\x00\x00\x00", + expected: true, + }, + { + name: "Valid (highest)", + addr: "\xe0\x00\x00\xff", + expected: true, + }, + { + name: "Invalid (before subnet)", + addr: "\xdf\xff\xff\xff", + expected: false, + }, + { + name: "Invalid (after subnet)", + addr: "\xe0\x00\x01\x00", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index f2403978c..c3a0407ac 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -98,12 +98,27 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - // IPv6AllRoutersMulticastAddress is a link-local multicast group that - // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local + // multicast group that all IPv6 routers MUST join, as per RFC 4291, section + // 2.8. Packets destined to this address will reach the router on an + // interface. + // + // The address is ff01::2. + IPv6AllRoutersInterfaceLocalMulticastAddress tcpip.Address = "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets // destined to this address will reach all routers on a link. // // The address is ff02::2. - IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + IPv6AllRoutersLinkLocalMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers in a site. 
+ // + // The address is ff05::2. + IPv6AllRoutersSiteLocalMulticastAddress tcpip.Address = "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, // section 5: @@ -142,11 +157,6 @@ const ( // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, // within the byte holding the field, as per RFC 4291 section 2.7. ipv6MulticastAddressScopeMask = 0xF - - // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within - // a multicast IPv6 address that indicates the address has link-local scope, - // as per RFC 4291 section 2.7. - ipv6LinkLocalMulticastScope = 2 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -381,25 +391,25 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { return tcpip.Address(lladdrb[:]) } -// IsV6LinkLocalAddress determines if the provided address is an IPv6 -// link-local address (fe80::/10). -func IsV6LinkLocalAddress(addr tcpip.Address) bool { +// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6 +// link-local unicast address, as defined by RFC 4291 section 2.5.6. +func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool { if len(addr) != IPv6AddressSize { return false } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } -// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback -// address. +// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback +// address, as defined by RFC 4291 section 2.5.3. func IsV6LoopbackAddress(addr tcpip.Address) bool { return addr == IPv6Loopback } -// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 -// link-local multicast address. +// IsV6LinkLocalMulticastAddress returns true iff the provided address is an +// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7. func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { - return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope + return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope } // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier @@ -462,7 +472,7 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) { case IsV6LinkLocalMulticastAddress(addr): return LinkLocalScope, nil - case IsV6LinkLocalAddress(addr): + case IsV6LinkLocalUnicastAddress(addr): return LinkLocalScope, nil default: @@ -520,3 +530,46 @@ func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) PrefixLen: IIDOffsetInIPv6Address * 8, } } + +// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by +// RFC 7346 section 2. 
+type IPv6MulticastScope uint8 + +// The various values for IPv6 multicast scopes, as per RFC 7346 section 2: +// +// +------+--------------------------+-------------------------+ +// | scop | NAME | REFERENCE | +// +------+--------------------------+-------------------------+ +// | 0 | Reserved | [RFC4291], RFC 7346 | +// | 1 | Interface-Local scope | [RFC4291], RFC 7346 | +// | 2 | Link-Local scope | [RFC4291], RFC 7346 | +// | 3 | Realm-Local scope | [RFC4291], RFC 7346 | +// | 4 | Admin-Local scope | [RFC4291], RFC 7346 | +// | 5 | Site-Local scope | [RFC4291], RFC 7346 | +// | 6 | Unassigned | | +// | 7 | Unassigned | | +// | 8 | Organization-Local scope | [RFC4291], RFC 7346 | +// | 9 | Unassigned | | +// | A | Unassigned | | +// | B | Unassigned | | +// | C | Unassigned | | +// | D | Unassigned | | +// | E | Global scope | [RFC4291], RFC 7346 | +// | F | Reserved | [RFC4291], RFC 7346 | +// +------+--------------------------+-------------------------+ +const ( + IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0) + IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1) + IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2) + IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3) + IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4) + IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5) + IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8) + IPv6GlobalMulticastScope = IPv6MulticastScope(0xE) + IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF) +) + +// V6MulticastScope returns the scope of a multicast address. +func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope { + return IPv6MulticastScope(addr[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index f10f446a6..89be84068 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -24,15 +24,17 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") +const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + +var ( + linkLocalAddr = testutil.MustParse6("fe80::1") + linkLocalMulticastAddr = testutil.MustParse6("ff02::1") + uniqueLocalAddr1 = testutil.MustParse6("fc00::1") + uniqueLocalAddr2 = testutil.MustParse6("fd00::2") + globalAddr = testutil.MustParse6("a000::1") ) func TestEthernetAdddressToModifiedEUI64(t *testing.T) { @@ -50,7 +52,7 @@ func TestEthernetAdddressToModifiedEUI64(t *testing.T) { } func TestLinkLocalAddr(t *testing.T) { - if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + if got, want := header.LinkLocalAddr(linkAddr), testutil.MustParse6("fe80::2:3ff:fe04:506"); got != want { t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) } } @@ -252,7 +254,7 @@ func 
TestIsV6LinkLocalMulticastAddress(t *testing.T) { } } -func TestIsV6LinkLocalAddress(t *testing.T) { +func TestIsV6LinkLocalUnicastAddress(t *testing.T) { tests := []struct { name string addr tcpip.Address @@ -287,8 +289,8 @@ func TestIsV6LinkLocalAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) } }) } @@ -373,3 +375,83 @@ func TestSolicitedNodeAddr(t *testing.T) { }) } } + +func TestV6MulticastScope(t *testing.T) { + tests := []struct { + addr tcpip.Address + want header.IPv6MulticastScope + }{ + { + addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6Reserved0MulticastScope, + }, + { + addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6InterfaceLocalMulticastScope, + }, + { + addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6LinkLocalMulticastScope, + }, + { + addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6RealmLocalMulticastScope, + }, + { + addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6AdminLocalMulticastScope, + }, + { + addr: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6SiteLocalMulticastScope, + }, + { + addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(6), + }, + { + addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(7), + }, + { + addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6OrganizationLocalMulticastScope, + }, + { + addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(9), + }, + { + addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(10), + }, + { + addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(11), + }, + { + addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(12), + }, + { + addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(13), + }, + { + addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6GlobalMulticastScope, + }, + { + addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6ReservedFMulticastScope, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { + if got := header.V6MulticastScope(test.addr); got != test.want { + t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want) + } + }) + } +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index d0a1a2492..1b5093e58 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) // TestNDPNeighborSolicit tests the 
functions of NDPNeighborSolicit. @@ -40,13 +41,13 @@ func TestNDPNeighborSolicit(t *testing.T) { // Test getting the Target Address. ns := NDPNeighborSolicit(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10") if got := ns.TargetAddress(); got != addr { t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) } // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11") ns.SetTargetAddress(addr2) if got := ns.TargetAddress(); got != addr2 { t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) @@ -69,7 +70,7 @@ func TestNDPNeighborAdvert(t *testing.T) { // Test getting the Target Address. na := NDPNeighborAdvert(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10") if got := na.TargetAddress(); got != addr { t.Errorf("got TargetAddress = %s, want %s", got, addr) } @@ -90,7 +91,7 @@ func TestNDPNeighborAdvert(t *testing.T) { } // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11") na.SetTargetAddress(addr2) if got := na.TargetAddress(); got != addr2 { t.Errorf("got TargetAddress = %s, want %s", got, addr2) @@ -277,7 +278,7 @@ func TestOpts(t *testing.T) { } const validLifetimeSeconds = 16909060 - const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18") + address := testutil.MustParse6("90a:b0c:d0e:f10:1112:1314:1516:1718") expectedRDNSSBytes := [...]byte{ // Type, Length diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index adc835d30..0df517000 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -216,104 +216,104 @@ const ( TCPDefaultMSS = 536 ) -// SourcePort returns the "source port" field of the tcp header. +// SourcePort returns the "source port" field of the TCP header. func (b TCP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[TCPSrcPortOffset:]) } -// DestinationPort returns the "destination port" field of the tcp header. +// DestinationPort returns the "destination port" field of the TCP header. func (b TCP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[TCPDstPortOffset:]) } -// SequenceNumber returns the "sequence number" field of the tcp header. +// SequenceNumber returns the "sequence number" field of the TCP header. func (b TCP) SequenceNumber() uint32 { return binary.BigEndian.Uint32(b[TCPSeqNumOffset:]) } -// AckNumber returns the "ack number" field of the tcp header. +// AckNumber returns the "ack number" field of the TCP header. func (b TCP) AckNumber() uint32 { return binary.BigEndian.Uint32(b[TCPAckNumOffset:]) } -// DataOffset returns the "data offset" field of the tcp header. The return +// DataOffset returns the "data offset" field of the TCP header. The return // value is the length of the TCP header in bytes. func (b TCP) DataOffset() uint8 { return (b[TCPDataOffset] >> 4) * 4 } -// Payload returns the data in the tcp packet. +// Payload returns the data in the TCP packet. func (b TCP) Payload() []byte { return b[b.DataOffset():] } -// Flags returns the flags field of the tcp header. +// Flags returns the flags field of the TCP header. 
func (b TCP) Flags() TCPFlags { return TCPFlags(b[TCPFlagsOffset]) } -// WindowSize returns the "window size" field of the tcp header. +// WindowSize returns the "window size" field of the TCP header. func (b TCP) WindowSize() uint16 { return binary.BigEndian.Uint16(b[TCPWinSizeOffset:]) } -// Checksum returns the "checksum" field of the tcp header. +// Checksum returns the "checksum" field of the TCP header. func (b TCP) Checksum() uint16 { return binary.BigEndian.Uint16(b[TCPChecksumOffset:]) } -// UrgentPointer returns the "urgent pointer" field of the tcp header. +// UrgentPointer returns the "urgent pointer" field of the TCP header. func (b TCP) UrgentPointer() uint16 { return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:]) } -// SetSourcePort sets the "source port" field of the tcp header. +// SetSourcePort sets the "source port" field of the TCP header. func (b TCP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port) } -// SetDestinationPort sets the "destination port" field of the tcp header. +// SetDestinationPort sets the "destination port" field of the TCP header. func (b TCP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port) } -// SetChecksum sets the checksum field of the tcp header. +// SetChecksum sets the checksum field of the TCP header. func (b TCP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum) } -// SetDataOffset sets the data offset field of the tcp header. headerLen should +// SetDataOffset sets the data offset field of the TCP header. headerLen should // be the length of the TCP header in bytes. func (b TCP) SetDataOffset(headerLen uint8) { b[TCPDataOffset] = (headerLen / 4) << 4 } -// SetSequenceNumber sets the sequence number field of the tcp header. +// SetSequenceNumber sets the sequence number field of the TCP header. func (b TCP) SetSequenceNumber(seqNum uint32) { binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum) } -// SetAckNumber sets the ack number field of the tcp header. +// SetAckNumber sets the ack number field of the TCP header. func (b TCP) SetAckNumber(ackNum uint32) { binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum) } -// SetFlags sets the flags field of the tcp header. +// SetFlags sets the flags field of the TCP header. func (b TCP) SetFlags(flags uint8) { b[TCPFlagsOffset] = flags } -// SetWindowSize sets the window size field of the tcp header. +// SetWindowSize sets the window size field of the TCP header. func (b TCP) SetWindowSize(rcvwnd uint16) { binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } -// SetUrgentPoiner sets the window size field of the tcp header. +// SetUrgentPoiner sets the window size field of the TCP header. func (b TCP) SetUrgentPoiner(urgentPointer uint16) { binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer) } -// CalculateChecksum calculates the checksum of the tcp segment. +// CalculateChecksum calculates the checksum of the TCP segment. // partialChecksum is the checksum of the network-layer pseudo-header // and the checksum of the segment data. func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { @@ -321,6 +321,13 @@ func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { return Checksum(b[:b.DataOffset()], partialChecksum) } +// IsChecksumValid returns true iff the TCP header's checksum is valid. 
+func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool { + xsum := PseudoHeaderChecksum(TCPProtocolNumber, src, dst, uint16(b.DataOffset())+payloadLength) + xsum = ChecksumCombine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + // Options returns a slice that holds the unparsed TCP options in the segment. func (b TCP) Options() []byte { return b[TCPMinimumSize:b.DataOffset()] @@ -340,7 +347,7 @@ func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } -// Encode encodes all the fields of the tcp header. +// Encode encodes all the fields of the TCP header. func (b TCP) Encode(t *TCPFields) { b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize) binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort) @@ -350,7 +357,7 @@ func (b TCP) Encode(t *TCPFields) { binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer) } -// EncodePartial updates a subset of the fields of the tcp header. It is useful +// EncodePartial updates a subset of the fields of the TCP header. It is useful // in cases when similar segments are produced. func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { // Add the total length and "flags" field contributions to the checksum. @@ -374,7 +381,7 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 } // ParseSynOptions parses the options received in a SYN segment and returns the -// relevant ones. opts should point to the option part of the TCP Header. +// relevant ones. opts should point to the option part of the TCP header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { limit := len(opts) diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 98bdd29db..ae9d167ff 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -64,17 +64,17 @@ const ( UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) -// SourcePort returns the "source port" field of the udp header. +// SourcePort returns the "source port" field of the UDP header. func (b UDP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } -// DestinationPort returns the "destination port" field of the udp header. +// DestinationPort returns the "destination port" field of the UDP header. func (b UDP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } -// Length returns the "length" field of the udp header. +// Length returns the "length" field of the UDP header. func (b UDP) Length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } @@ -84,39 +84,46 @@ func (b UDP) Payload() []byte { return b[UDPMinimumSize:] } -// Checksum returns the "checksum" field of the udp header. +// Checksum returns the "checksum" field of the UDP header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) } -// SetSourcePort sets the "source port" field of the udp header. +// SetSourcePort sets the "source port" field of the UDP header. func (b UDP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[udpSrcPort:], port) } -// SetDestinationPort sets the "destination port" field of the udp header. +// SetDestinationPort sets the "destination port" field of the UDP header. func (b UDP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[udpDstPort:], port) } -// SetChecksum sets the "checksum" field of the udp header. 
+// SetChecksum sets the "checksum" field of the UDP header. func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } -// SetLength sets the "length" field of the udp header. +// SetLength sets the "length" field of the UDP header. func (b UDP) SetLength(length uint16) { binary.BigEndian.PutUint16(b[udpLength:], length) } -// CalculateChecksum calculates the checksum of the udp packet, given the +// CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. return Checksum(b[:UDPMinimumSize], partialChecksum) } -// Encode encodes all the fields of the udp header. +// IsChecksumValid returns true iff the UDP header's checksum is valid. +func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool { + xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length()) + xsum = ChecksumCombine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + +// Encode encodes all the fields of the UDP header. func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index fa8814bac..7b1ff44f4 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -21,6 +21,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d59d678b2..6905b9ccb 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -33,6 +33,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 018d6a578..9b3714f9e 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -30,20 +30,16 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) const ( nicID = 1 - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") + stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - unknownAddr = tcpip.Address("\x0a\x00\x00\x03") - defaultChannelSize = 1 defaultMTU = 65536 @@ -54,6 +50,12 @@ const ( eventChanSize = 32 ) +var ( + stackAddr = testutil.MustParse4("10.0.0.1") + remoteAddr = testutil.MustParse4("10.0.0.2") + unknownAddr = testutil.MustParse4("10.0.0.3") +) + type eventType uint8 const ( diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index b9f129728..ac35d81e7 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -156,14 +156,6 @@ type 
GenericMulticastProtocolOptions struct { // // Unsolicited reports are transmitted when a group is newly joined. MaxUnsolicitedReportDelay time.Duration - - // AllNodesAddress is a multicast address that all nodes on a network should - // be a member of. - // - // This address will not have the generic multicast protocol performed on it; - // it will be left in the non member/listener state, and packets will never - // be sent for it. - AllNodesAddress tcpip.Address } // MulticastGroupProtocol is a multicast group protocol whose core state machine @@ -188,6 +180,10 @@ type MulticastGroupProtocol interface { // SendLeave sends a multicast leave for the specified group address. SendLeave(groupAddress tcpip.Address) tcpip.Error + + // ShouldPerformProtocol returns true iff the protocol should be performed for + // the specified group. + ShouldPerformProtocol(tcpip.Address) bool } // GenericMulticastProtocolState is the per interface generic multicast protocol @@ -455,20 +451,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t info.lastToSendReport = false - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { info.state = idleMember return } @@ -537,20 +520,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres return } - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { return } @@ -627,20 +597,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr return } - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. 
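
The refactor above replaces the single AllNodesAddress option with a ShouldPerformProtocol callback, so each protocol decides per group whether the generic state machine should run at all. A hedged sketch of how an implementer expresses the old IGMP special case; only this one method of MulticastGroupProtocol is shown, and the type name igmpLike is illustrative.

    package example

    import (
        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/header"
    )

    // igmpLike opts the all-systems group (224.0.0.1) out of the protocol,
    // mirroring what the removed AllNodesAddress option used to express.
    type igmpLike struct{}

    func (igmpLike) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
        return groupAddress != header.IPv4AllSystems
    }
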
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { return } diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 381460c82..0b51563cd 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -43,6 +43,8 @@ type mockMulticastGroupProtocolProtectedFields struct { type mockMulticastGroupProtocol struct { t *testing.T + skipProtocolAddress tcpip.Address + mu mockMulticastGroupProtocolProtectedFields } @@ -165,6 +167,11 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip return nil } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + return groupAddress != m.skipProtocolAddress +} + func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { m.mu.Lock() defer m.mu.Unlock() @@ -193,10 +200,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr cmp.FilterPath( func(p cmp.Path) bool { switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress": return true + default: + return false } - return false }, cmp.Ignore(), ), @@ -225,14 +233,13 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(0)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, }) // Joining a group should send a report immediately and another after @@ -279,14 +286,13 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(1)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, }) mgp.joinGroup(test.addr) @@ -356,14 +362,13 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(2)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) @@ -446,14 +451,13 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) @@ -574,14 +578,13 @@ func TestJoinCount(t *testing.T) { } func 
TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index b6f39ddb1..d06b26309 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -21,53 +21,56 @@ import "gvisor.dev/gvisor/pkg/tcpip" // MultiCounterIPStats holds IP statistics, each counter may have several // versions. type MultiCounterIPStats struct { - // PacketsReceived is the total number of IP packets received from the link - // layer. + // PacketsReceived is the number of IP packets received from the link layer. PacketsReceived tcpip.MultiCounterStat - // DisabledPacketsReceived is the total number of IP packets received from the - // link layer when the IP layer is disabled. + // DisabledPacketsReceived is the number of IP packets received from the link + // layer when the IP layer is disabled. DisabledPacketsReceived tcpip.MultiCounterStat - // InvalidDestinationAddressesReceived is the total number of IP packets - // received with an unknown or invalid destination address. + // InvalidDestinationAddressesReceived is the number of IP packets received + // with an unknown or invalid destination address. InvalidDestinationAddressesReceived tcpip.MultiCounterStat - // InvalidSourceAddressesReceived is the total number of IP packets received - // with a source address that should never have been received on the wire. + // InvalidSourceAddressesReceived is the number of IP packets received with a + // source address that should never have been received on the wire. InvalidSourceAddressesReceived tcpip.MultiCounterStat - // PacketsDelivered is the total number of incoming IP packets that are - // successfully delivered to the transport layer. + // PacketsDelivered is the number of incoming IP packets that are successfully + // delivered to the transport layer. PacketsDelivered tcpip.MultiCounterStat - // PacketsSent is the total number of IP packets sent via WritePacket. + // PacketsSent is the number of IP packets sent via WritePacket. PacketsSent tcpip.MultiCounterStat - // OutgoingPacketErrors is the total number of IP packets which failed to - // write to a link-layer endpoint. + // OutgoingPacketErrors is the number of IP packets which failed to write to a + // link-layer endpoint. OutgoingPacketErrors tcpip.MultiCounterStat - // MalformedPacketsReceived is the total number of IP Packets that were - // dropped due to the IP packet header failing validation checks. + // MalformedPacketsReceived is the number of IP Packets that were dropped due + // to the IP packet header failing validation checks. MalformedPacketsReceived tcpip.MultiCounterStat - // MalformedFragmentsReceived is the total number of IP Fragments that were - // dropped due to the fragment failing validation checks. + // MalformedFragmentsReceived is the number of IP Fragments that were dropped + // due to the fragment failing validation checks. MalformedFragmentsReceived tcpip.MultiCounterStat - // IPTablesPreroutingDropped is the total number of IP packets dropped in the + // IPTablesPreroutingDropped is the number of IP packets dropped in the // Prerouting chain. 
IPTablesPreroutingDropped tcpip.MultiCounterStat - // IPTablesInputDropped is the total number of IP packets dropped in the Input + // IPTablesInputDropped is the number of IP packets dropped in the Input // chain. IPTablesInputDropped tcpip.MultiCounterStat - // IPTablesOutputDropped is the total number of IP packets dropped in the - // Output chain. + // IPTablesOutputDropped is the number of IP packets dropped in the Output + // chain. IPTablesOutputDropped tcpip.MultiCounterStat + // IPTablesPostroutingDropped is the number of IP packets dropped in the + // Postrouting chain. + IPTablesPostroutingDropped tcpip.MultiCounterStat + // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out // of IPStats. @@ -98,6 +101,7 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) + m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped) m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived) m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived) m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a4edc69c7..dbd674634 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "fmt" "strings" "testing" @@ -29,23 +30,25 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -const ( - localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") - remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") - ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") - ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") - ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") - localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") - ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") - ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - nicID = 1 +const nicID = 1 + +var ( + localIPv4Addr = testutil.MustParse4("10.0.0.1") + remoteIPv4Addr = testutil.MustParse4("10.0.0.2") + ipv4SubnetAddr = testutil.MustParse4("10.0.0.0") + ipv4SubnetMask = testutil.MustParse4("255.255.255.0") + ipv4Gateway = testutil.MustParse4("10.0.0.3") + localIPv6Addr = testutil.MustParse6("a00::1") + remoteIPv6Addr = testutil.MustParse6("a00::2") + ipv6SubnetAddr = testutil.MustParse6("a00::") + ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00") + ipv6Gateway = testutil.MustParse6("a00::3") ) var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ @@ -1938,3 +1941,80 @@ func TestICMPInclusionSize(t *testing.T) { }) } } + +func TestJoinLeaveAllRoutersGroup(t *testing.T) { + const nicID = 1 + + tests := []struct 
{ + name string + netProto tcpip.NetworkProtocolNumber + protoFactory stack.NetworkProtocolFactory + allRoutersAddr tcpip.Address + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + protoFactory: ipv4.NewProtocol, + allRoutersAddr: header.IPv4AllRoutersGroup, + }, + { + name: "IPv6 Interface Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress, + }, + { + name: "IPv6 Link Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress, + }, + { + name: "IPv6 Site Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, nicDisabled := range [...]bool{true, false} { + t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + opts := stack.NICOptions{Disabled: nicDisabled} + if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) + } + + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if !got { + t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, false); err != nil { + t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 5e7f10f4b..7ee0495d9 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -45,6 +45,7 @@ go_test( "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index f3fc1c87e..b1ac29294 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -126,6 +126,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error { return err } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + // As per RFC 2236 section 6 page 10, + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. 
The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + return groupAddress != header.IPv4AllSystems +} + // init sets up an igmpState struct, and is required to be called before using // a new igmpState. // @@ -137,7 +148,6 @@ func (igmp *igmpState) init(ep *endpoint) { Clock: ep.protocol.stack.Clock(), Protocol: igmp, MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, - AllNodesAddress: header.IPv4AllSystems, }) igmp.igmpV1Present = igmpV1PresentDefault igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index e5e1b89cc..4bd6f462e 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -26,18 +26,22 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") - multicastAddr = tcpip.Address("\xe0\x00\x00\x03") nicID = 1 defaultTTL = 1 defaultPrefixLength = 24 ) +var ( + stackAddr = testutil.MustParse4("10.0.0.1") + remoteAddr = testutil.MustParse4("10.0.0.2") + multicastAddr = testutil.MustParse4("224.0.0.3") +) + // validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet // sent to the provided address with the passed fields set. Raises a t.Error if // any field does not match. @@ -292,7 +296,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPLeaveGroup, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, @@ -302,7 +306,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPMembershipQuery, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: true, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() }, @@ -312,7 +316,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPv1MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, @@ -322,7 +326,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPv2MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, @@ -332,7 +336,7 @@ func TestIGMPPacketValidation(t *testing.T) { 
messageType: header.IGMPv2MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{ - {Address: tcpip.Address("\x0a\x00\x0f\x01"), PrefixLen: 24}, + {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24}, {Address: stackAddr, PrefixLen: 24}, }, srcAddr: remoteAddr, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1a5661ca4..a82a5790d 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -150,6 +150,38 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if forwarding { + // There does not seem to be an RFC requirement for a node to join the all + // routers multicast address but + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml + // specifies the address as a group for all routers on a subnet so we join + // the group here. + if err := e.joinGroupLocked(header.IPv4AllRoutersGroup); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } + + return + } + + switch err := e.leaveGroupLocked(header.IPv4AllRoutersGroup).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } +} + // Enable implements stack.NetworkEndpoint. func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() @@ -226,7 +258,7 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) { + switch err := e.leaveGroupLocked(header.IPv4AllSystems).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) @@ -383,6 +415,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + // Postrouting NAT can only change the source address, and does not alter the + // route or outgoing interface of the packet. + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesPostroutingDropped.Increment() + return nil + } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) @@ -454,9 +495,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. 
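
The new forwarding guard leans on the link-local predicates from the header package. As a reference point, the IPv4 unicast check amounts to a 169.254.0.0/16 prefix match (RFC 3927); the sketch below is an illustrative equivalent, not the library's actual implementation.

    package example

    import "gvisor.dev/gvisor/pkg/tcpip"

    // isV4LinkLocalUnicast reports whether addr falls in 169.254.0.0/16, the
    // range a router must never forward per RFC 3927 section 7.
    func isV4LinkLocalUnicast(addr tcpip.Address) bool {
        return len(addr) == 4 && addr[0] == 169 && addr[1] == 254
    }
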
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) - stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - for pkt := range dropped { + outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) + for pkt := range outputDropped { pkts.Remove(pkt) } @@ -478,6 +519,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } + // We ignore the list of NAT-ed packets here because Postrouting NAT can only + // change the source address, and does not alter the route or outgoing + // interface of the packet. + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName) + stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) + for pkt := range postroutingDropped { + pkts.Remove(pkt) + } + // The rest of the packets can be delivered to the NIC as a batch. pktsLen := pkts.Len() written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) @@ -485,7 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) // Dropped packets aren't errors, so include them in the return value. - return locallyDelivered + written + len(dropped), err + return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. @@ -551,6 +601,22 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // forwardPacket attempts to forward a packet to its final destination. func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv4(pkt.NetworkHeader().View()) + + dstAddr := h.DestinationAddress() + if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { + // As per RFC 3927 section 7, + // + // A router MUST NOT forward a packet with an IPv4 Link-Local source or + // destination address, irrespective of the router's default route + // configuration or routes obtained from dynamic routing protocols. + // + // A router which receives a packet with an IPv4 Link-Local source or + // destination address MUST NOT forward the packet. This prevents + // forwarding of packets back onto the network segment from which they + // originated, or to any other segment. + return nil + } + ttl := h.TTL() if ttl == 0 { // As per RFC 792 page 6, Time Exceeded Message, @@ -589,8 +655,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } } - dstAddr := h.DestinationAddress() - // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { ep.handleValidatedPacket(h, pkt) @@ -1114,28 +1178,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) return nil, false } - // There has been some confusion regarding verifying checksums. We need - // just look for negative 0 (0xffff) as the checksum, as it's not possible to - // get positive 0 (0) for the checksum. Some bad implementations could get it - // when doing entry replacement in the early days of the Internet, - // however the lore that one needs to check for both persists. 
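
The removed comment block boils down to one verification rule from RFC 1071: re-sum the header including its checksum field and accept when the one's complement sum is all ones (-0). The new h.IsChecksumValid() encapsulates exactly that; a minimal sketch of the same check, with the function name ipv4HeaderChecksumOK being illustrative.

    package example

    import "gvisor.dev/gvisor/pkg/tcpip/header"

    // ipv4HeaderChecksumOK re-computes the one's complement sum over the whole
    // IPv4 header, checksum field included; a valid header yields 0xffff.
    func ipv4HeaderChecksumOK(h header.IPv4) bool {
        return header.Checksum(h[:h.HeaderLength()], 0) == 0xffff
    }
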
- // - // RFC 1624 section 1 describes the source of this confusion as: - // [the partial recalculation method described in RFC 1071] computes a - // result for certain cases that differs from the one obtained from - // scratch (one's complement of one's complement sum of the original - // fields). - // - // However RFC 1624 section 5 clarifies that if using the verification method - // "recommended by RFC 1071, it does not matter if an intermediate system - // generated a -0 instead of +0". - // - // RFC1071 page 1 specifies the verification method as: - // (3) To check a checksum, the 1's complement sum is computed over the - // same set of octets, including the checksum field. If the result - // is all 1 bits (-0 in 1's complement arithmetic), the check - // succeeds. - if h.CalculateChecksum() != 0xffff { + if !h.IsChecksumValid() { return nil, false } @@ -1168,12 +1211,27 @@ func (p *protocol) Forwarding() bool { return uint8(atomic.LoadUint32(&p.forwarding)) == 1 } +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) + } + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) +} + // SetForwarding implements stack.ForwardingNetworkProtocol. func (p *protocol) SetForwarding(v bool) { - if v { - atomic.StoreUint32(&p.forwarding, 1) - } else { - atomic.StoreUint32(&p.forwarding, 0) + p.mu.Lock() + defer p.mu.Unlock() + + if !p.setForwarding(v) { + return + } + + for _, ep := range p.mu.eps { + ep.transitionForwarding(v) } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index eba91c68c..d49dff4d5 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -39,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -2612,34 +2613,36 @@ func TestWriteStats(t *testing.T) { const nPackets = 3 tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectOutputDropped int + expectPostroutingDropped int + expectWritten int }{ { name: "Accept all", // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { name: "Accept all with error", // No setup needed, tables accept everything by default. 
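
SetForwarding now guards the endpoint transitions with a compare-and-swap so that the join/leave work only runs when the flag actually flips. A stand-alone sketch of that idiom, assuming the flag lives in a package-level uint32; the names are illustrative.

    package example

    import "sync/atomic"

    var forwarding uint32 // 0 = host, 1 = router

    // setForwarding returns true only when the value actually changed, so a
    // repeated SetForwarding(true) triggers the router transition exactly once.
    func setForwarding(v bool) bool {
        if v {
            return atomic.CompareAndSwapUint32(&forwarding, 0, 1)
        }
        return atomic.CompareAndSwapUint32(&forwarding, 1, 0)
    }
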
- setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets - 1, }, { - name: "Drop all", + name: "Drop all with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] @@ -2648,16 +2651,32 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %s", err) } }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: nPackets, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { - name: "Drop some", + name: "Drop all with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: 0, + expectPostroutingDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule that matches only 1 // of the 3 packets. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) // We'll match and DROP the last packet. @@ -2670,10 +2689,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %s", err) } }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 1, + expectPostroutingDropped: 0, + expectWritten: nPackets, + }, { + name: "Drop some with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Postrouting DROP rule that matches only 1 + // of the 3 packets. + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. 
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 1, + expectWritten: nPackets, }, } @@ -2724,13 +2766,16 @@ func TestWriteStats(t *testing.T) { nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { + t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { + t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped) } if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) } }) } @@ -2995,12 +3040,14 @@ func TestCloseLocking(t *testing.T) { nicID1 = 1 nicID2 = 2 - src = tcpip.Address("\x10\x00\x00\x01") - dst = tcpip.Address("\x10\x00\x00\x02") - iterations = 1000 ) + var ( + src = tcptestutil.MustParse4("16.0.0.1") + dst = tcptestutil.MustParse4("16.0.0.2") + ) + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index bb9a02ed0..db998e83e 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -66,5 +66,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index a142b76c1..b2a80e1e9 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -273,7 +273,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if iph.HopLimit() != header.MLDHopLimit { return false } - if !header.IsV6LinkLocalAddress(iph.SourceAddress()) { + if !header.IsV6LinkLocalUnicastAddress(iph.SourceAddress()) { return false } return true @@ -804,7 +804,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r routerAddr := srcAddr // Is the IP Source Address a link-local address? - if !header.IsV6LinkLocalAddress(routerAddr) { + if !header.IsV6LinkLocalUnicastAddress(routerAddr) { // ...No, silently drop the packet. 
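
The rename to IsV6LinkLocalUnicastAddress makes explicit that the check matches fe80::/10 unicast addresses only, not link-local multicast. An illustrative equivalent of that predicate, not the library's actual code:

    package example

    import "gvisor.dev/gvisor/pkg/tcpip"

    // isV6LinkLocalUnicast reports whether addr is in fe80::/10.
    func isV6LinkLocalUnicast(addr tcpip.Address) bool {
        return len(addr) == 16 && addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
    }
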
received.invalid.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c6d9d8f0d..2e515379c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -314,7 +314,7 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { // Snooping switches MUST manage multicast forwarding state based on MLD // Report and Done messages sent with the unspecified address as the // IPv6 source address. - if header.IsV6LinkLocalAddress(addr) { + if header.IsV6LinkLocalUnicastAddress(addr) { e.mu.mld.sendQueuedReports() } } @@ -410,22 +410,65 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t // // Must only be called when the forwarding status changes. func (e *endpoint) transitionForwarding(forwarding bool) { + allRoutersGroups := [...]tcpip.Address{ + header.IPv6AllRoutersInterfaceLocalMulticastAddress, + header.IPv6AllRoutersLinkLocalMulticastAddress, + header.IPv6AllRoutersSiteLocalMulticastAddress, + } + e.mu.Lock() defer e.mu.Unlock() - if !e.Enabled() { - return - } - if forwarding { // When transitioning into an IPv6 router, host-only state (NDP discovered // routers, discovered on-link prefixes, and auto-generated addresses) is // cleaned up/invalidated and NDP router solicitations are stopped. e.mu.ndp.stopSolicitingRouters() e.mu.ndp.cleanupState(true /* hostOnly */) - } else { - // When transitioning into an IPv6 host, NDP router solicitations are - // started. + + // As per RFC 4291 section 2.8: + // + // A router is required to recognize all addresses that a host is + // required to recognize, plus the following addresses as identifying + // itself: + // + // o The All-Routers multicast addresses defined in Section 2.7.1. + // + // As per RFC 4291 section 2.7.1, + // + // All Routers Addresses: FF01:0:0:0:0:0:0:2 + // FF02:0:0:0:0:0:0:2 + // FF05:0:0:0:0:0:0:2 + // + // The above multicast addresses identify the group of all IPv6 routers, + // within scope 1 (interface-local), 2 (link-local), or 5 (site-local). + for _, g := range allRoutersGroups { + if err := e.joinGroupLocked(g); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err)) + } + } + + return + } + + for _, g := range allRoutersGroups { + switch err := e.leaveGroupLocked(g).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } + } + + // When transitioning into an IPv6 host, NDP router solicitations are + // started if the endpoint is enabled. + // + // If the endpoint is not currently enabled, routers will be solicited when + // the endpoint becomes enabled (if it is still a host). + if e.Enabled() { e.mu.ndp.startSolicitingRouters() } } @@ -573,7 +616,7 @@ func (e *endpoint) disableLocked() { e.mu.ndp.cleanupState(false /* hostOnly */) // The endpoint may have already left the multicast group. 
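
For reference, the three all-routers groups a router-mode IPv6 endpoint joins correspond to these textual addresses (RFC 4291 section 2.7.1); the snippet uses the same testutil.MustParse6 helper the tests in this change rely on, and the variable name allRouters is illustrative.

    package example

    import (
        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/testutil"
    )

    // allRouters lists the groups joined on the transition to forwarding.
    var allRouters = []tcpip.Address{
        testutil.MustParse6("ff01::2"), // interface-local
        testutil.MustParse6("ff02::2"), // link-local
        testutil.MustParse6("ff05::2"), // site-local
    }
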
- switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { + switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) @@ -726,6 +769,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + // Postrouting NAT can only change the source address, and does not alter the + // route or outgoing interface of the packet. + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesPostroutingDropped.Increment() + return nil + } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { @@ -797,9 +849,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) - stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - for pkt := range dropped { + outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) + for pkt := range outputDropped { pkts.Remove(pkt) } @@ -820,6 +872,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe locallyDelivered++ } + // We ignore the list of NAT-ed packets here because Postrouting NAT can only + // change the source address, and does not alter the route or outgoing + // interface of the packet. + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName) + stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) + for pkt := range postroutingDropped { + pkts.Remove(pkt) + } + // The rest of the packets can be delivered to the NIC as a batch. pktsLen := pkts.Len() written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) @@ -827,7 +888,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) // Dropped packets aren't errors, so include them in the return value. - return locallyDelivered + written + len(dropped), err + return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. @@ -869,6 +930,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // forwardPacket attempts to forward a packet to its final destination. func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv6(pkt.NetworkHeader().View()) + + dstAddr := h.DestinationAddress() + if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { + // As per RFC 4291 section 2.5.6, + // + // Routers must not forward any packets with Link-Local source or + // destination addresses to other links. 
+ return nil + } + hopLimit := h.HopLimit() if hopLimit <= 1 { // As per RFC 4443 section 3.3, @@ -881,8 +952,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) } - dstAddr := h.DestinationAddress() - // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { ep.handleValidatedPacket(h, pkt) @@ -1571,7 +1640,7 @@ func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { var linkLocalAddr tcpip.Address e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { if addressEndpoint.IsAssigned(false /* allowExpired */) { - if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalUnicastAddress(addr) { linkLocalAddr = addr return false } @@ -1979,9 +2048,9 @@ func (p *protocol) Forwarding() bool { // Returns true if the forwarding status was updated. func (p *protocol) setForwarding(v bool) bool { if v { - return atomic.SwapUint32(&p.forwarding, 1) == 0 + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) } - return atomic.SwapUint32(&p.forwarding, 0) == 1 + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) } // SetForwarding implements stack.ForwardingNetworkProtocol. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index c206cebeb..a620e9ad9 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -2468,34 +2468,36 @@ func TestFragmentReassemblyTimeout(t *testing.T) { func TestWriteStats(t *testing.T) { const nPackets = 3 tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectOutputDropped int + expectPostroutingDropped int + expectWritten int }{ { name: "Accept all", // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { name: "Accept all with error", // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets - 1, }, { - name: "Drop all", + name: "Drop all with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule. 
- t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] @@ -2504,16 +2506,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: nPackets, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { - name: "Drop some", + name: "Drop all with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: 0, + expectPostroutingDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule that matches only 1 // of the 3 packets. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) // We'll match and DROP the last packet. @@ -2526,10 +2545,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 1, + expectPostroutingDropped: 0, + expectWritten: nPackets, + }, { + name: "Drop some with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Postrouting DROP rule that matches only 1 + // of the 3 packets. + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. 
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 1, + expectWritten: nPackets, }, } @@ -2578,13 +2620,16 @@ func TestWriteStats(t *testing.T) { nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { + t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { + t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped) } if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) } }) } diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index dd153466d..165b7d2d2 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -76,10 +76,29 @@ func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) // // Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error { - _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) return err } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + // As per RFC 2710 section 5 page 10, + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + // + // MLD messages are never sent for multicast addresses whose scope is 0 + // (reserved) or 1 (node-local). + if groupAddress == header.IPv6AllNodesMulticastAddress { + return false + } + + scope := header.V6MulticastScope(groupAddress) + return scope != header.IPv6Reserved0MulticastScope && scope != header.IPv6InterfaceLocalMulticastScope +} + // init sets up an mldState struct, and is required to be called before using // a new mldState. 
// @@ -91,7 +110,6 @@ func (mld *mldState) init(ep *endpoint) { Clock: ep.protocol.stack.Clock(), Protocol: mld, MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, - AllNodesAddress: header.IPv6AllNodesMulticastAddress, }) } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 85a8f9944..71d1c3e28 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -27,15 +27,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var ( + linkLocalAddr = testutil.MustParse6("fe80::1") + globalAddr = testutil.MustParse6("a80::1") + globalMulticastAddr = testutil.MustParse6("ff05:100::2") + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) ) @@ -93,7 +92,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { if p, ok := e.Read(); !ok { t.Fatal("expected a done message to be sent") } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) } } @@ -354,10 +353,8 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, ho } func TestMLDPacketValidation(t *testing.T) { - const ( - nicID = 1 - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ) + const nicID = 1 + linkLocalAddr2 := testutil.MustParse6("fe80::2") tests := []struct { name string @@ -464,3 +461,141 @@ func TestMLDPacketValidation(t *testing.T) { }) } } + +func TestMLDSkipProtocol(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + group tcpip.Address + expectReport bool + }{ + { + name: "Reserverd0", + group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: false, + }, + { + name: "Interface Local", + group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: false, + }, + { + name: "Link Local", + group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Realm Local", + group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Admin Local", + group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Site Local", + group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(6)", + group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(7)", + group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Organization Local", + group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(9)", + group: 
"\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(A)", + group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(B)", + group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(C)", + group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(D)", + group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Global", + group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "ReservedF", + group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil { + t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err) + } + if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group) + } + + if !test.expectReport { + if p, ok := e.Read(); ok { + t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p) + } + + return + } + + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 536493f87..a110faa54 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -737,7 +737,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { prefix := opt.Subnet() // Is the prefix a link-local? - if header.IsV6LinkLocalAddress(prefix.ID()) { + if header.IsV6LinkLocalUnicastAddress(prefix.ID()) { // ...Yes, skip as per RFC 4861 section 6.3.4, // and RFC 4862 section 5.5.3.b (for SLAAC). continue @@ -1703,7 +1703,7 @@ func (ndp *ndpState) startSolicitingRouters() { // the unspecified address if no address is assigned // to the sending interface. 
localAddr := header.IPv6Any - if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersLinkLocalMulticastAddress, false); addressEndpoint != nil { localAddr = addressEndpoint.AddressWithPrefix().Address addressEndpoint.DecRef() } @@ -1730,7 +1730,7 @@ func (ndp *ndpState) startSolicitingRouters() { icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpData, Src: localAddr, - Dst: header.IPv6AllRoutersMulticastAddress, + Dst: header.IPv6AllRoutersLinkLocalMulticastAddress, })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1739,14 +1739,14 @@ func (ndp *ndpState) startSolicitingRouters() { }) sent := ndp.ep.stats.icmp.packetsSent - if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } - if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.dropped.Increment() // Don't send any more messages if we had an error. remaining = 0 diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index ecd5003a7..1b96b1fb8 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -30,22 +30,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - stackIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") defaultIPv4PrefixLength = 24 - linkLocalIPv6Addr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalIPv6Addr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - - ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") - ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") - ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") - ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") - ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") igmpMembershipQuery = uint8(header.IGMPMembershipQuery) igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) @@ -59,6 +50,19 @@ const ( ) var ( + stackIPv4Addr = testutil.MustParse4("10.0.0.1") + linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1") + linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2") + + ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3") + ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4") + ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5") + ipv6MulticastAddr1 = testutil.MustParse6("ff02::3") + ipv6MulticastAddr2 = testutil.MustParse6("ff02::4") + 
ipv6MulticastAddr3 = testutil.MustParse6("ff02::5") +) + +var ( // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the // NIC will wait before sending an unsolicited report after joining a // multicast group, in deciseconds. @@ -194,7 +198,7 @@ func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, c if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") } else { - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC) } // Should not send any more packets. @@ -606,7 +610,7 @@ func TestMGPLeaveGroup(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, checkInitialGroups: checkInitialIPv6Groups, }, @@ -1014,7 +1018,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr) }, getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { t.Helper() diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 210262703..b7f6d52ae 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -21,6 +21,7 @@ go_test( library = ":ports", deps = [ "//pkg/tcpip", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 678199371..b5b013b64 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -17,6 +17,7 @@ package ports import ( + "math" "math/rand" "sync/atomic" @@ -24,7 +25,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -const anyIPAddress tcpip.Address = "" +const ( + firstEphemeral = 16000 + anyIPAddress tcpip.Address = "" +) // Reservation describes a port reservation. type Reservation struct { @@ -220,10 +224,8 @@ type PortManager struct { func NewPortManager() *PortManager { return &PortManager{ allocatedPorts: make(map[portDescriptor]addrToDevice), - // Match Linux's default ephemeral range. See: - // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842 - firstEphemeral: 32768, - numEphemeral: 28232, + firstEphemeral: firstEphemeral, + numEphemeral: math.MaxUint16 - firstEphemeral + 1, } } @@ -242,13 +244,13 @@ func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err numEphemeral := pm.numEphemeral pm.ephemeralMu.RUnlock() - offset := uint16(rand.Int31n(int32(numEphemeral))) + offset := uint32(rand.Int31n(int32(numEphemeral))) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } // portHint atomically reads and returns the pm.hint value. -func (pm *PortManager) portHint() uint16 { - return uint16(atomic.LoadUint32(&pm.hint)) +func (pm *PortManager) portHint() uint32 { + return atomic.LoadUint32(&pm.hint) } // incPortHint atomically increments pm.hint by 1. 
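The ports changes above widen the random offset from uint16 to uint32 and move the default ephemeral range to start at 16000 and span the rest of the 16-bit port space. Once offset+i can exceed 65535, doing that sum in uint16 wraps and breaks the guarantee that the loop visits every port in [first, first+count) exactly once, so allocation can fail even while free ports remain; this is what TestOverflow below exercises. The following is a minimal, self-contained sketch (not part of the change) contrasting the two arithmetics:

package main

import "fmt"

// Illustrative only: with the wider ephemeral range, offset+i can exceed
// 65535. Summing in uint16 wraps, so the loop no longer enumerates each port
// in [first, first+count) exactly once; the uint32 form does.
func main() {
	const (
		first = uint16(16000)
		count = uint16(49536) // math.MaxUint16 - 16000 + 1
	)
	offset := uint32(49000) // a large but in-range random offset

	distinct16 := make(map[uint16]struct{})
	for i := uint16(0); i < count; i++ {
		port := first + (uint16(offset)+i)%count // uint16(offset)+i wraps at 65536
		distinct16[port] = struct{}{}
	}

	distinct32 := make(map[uint16]struct{})
	for i := uint32(0); i < uint32(count); i++ {
		port := uint16(uint32(first) + (offset+i)%uint32(count)) // no wrap
		distinct32[port] = struct{}{}
	}

	fmt.Println("uint16 math: distinct ports =", len(distinct16)) // fewer than count
	fmt.Println("uint32 math: distinct ports =", len(distinct32)) // exactly count
}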
@@ -260,7 +262,7 @@ func (pm *PortManager) incPortHint() { // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. -func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) { +func (pm *PortManager) PickEphemeralPortStable(offset uint32, testPort PortTester) (port uint16, err tcpip.Error) { pm.ephemeralMu.RLock() firstEphemeral := pm.firstEphemeral numEphemeral := pm.numEphemeral @@ -277,9 +279,9 @@ func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTeste // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { - for i := uint16(0); i < count; i++ { - port = first + (offset+i)%count +func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { + for i := uint32(0); i < uint32(count); i++ { + port := uint16(uint32(first) + (offset+i)%uint32(count)) ok, err := testPort(port) if err != nil { return 0, err diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 0f43dc8f8..6c4fb8c68 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -15,19 +15,23 @@ package ports import ( + "math" "math/rand" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 +) - fakeIPAddress = tcpip.Address("\x08\x08\x08\x08") - fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") +var ( + fakeIPAddress = testutil.MustParse4("8.8.8.8") + fakeIPAddress1 = testutil.MustParse4("8.8.8.9") ) type portReserveTestAction struct { @@ -479,7 +483,7 @@ func TestPickEphemeralPortStable(t *testing.T) { if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { t.Fatalf("failed to set ephemeral port range: %s", err) } - portOffset := uint16(rand.Int31n(int32(numEphemeralPorts))) + portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) port, err := pm.PickEphemeralPortStable(portOffset, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) @@ -490,3 +494,29 @@ func TestPickEphemeralPortStable(t *testing.T) { }) } } + +// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a +// port allocation failure. +func TestOverflow(t *testing.T) { + // Use a small range and start at offsets that will cause an overflow. + count := uint16(50) + for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ { + reservedPorts := make(map[uint16]struct{}) + // Ensure we can reserve everything in the allowed range. 
+ for i := uint16(0); i < count; i++ { + port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) { + if _, ok := reservedPorts[port]; !ok { + reservedPorts[port] = struct{}{} + return true, nil + } + return false, nil + }) + if err != nil { + t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts)) + } + if port < firstEphemeral || port > firstEphemeral+count { + t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1) + } + } + } +} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index dc37e61a4..a6c877158 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -58,6 +58,9 @@ type SocketOptionsHandler interface { // changed. The handler is invoked with the new value for the socket send // buffer size. It also returns the newly set value. OnSetSendBufferSize(v int64) (newSz int64) + + // OnSetReceiveBufferSize is invoked to set the SO_RCVBUFSIZE. + OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -99,6 +102,11 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) { return v } +// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize. +func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) { + return v +} + // StackHandler holds methods to access the stack options. These must be // implemented by the stack. type StackHandler interface { @@ -207,6 +215,14 @@ type SocketOptions struct { // sendBufferSize determines the send buffer size for this socket. sendBufferSize int64 + // getReceiveBufferLimits provides the handler to get the min, default and + // max size for receive buffer. It is initialized at the creation time and + // will not change. + getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` + + // receiveBufferSize determines the receive buffer size for this socket. + receiveBufferSize int64 + // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -217,10 +233,11 @@ type SocketOptions struct { // InitHandler initializes the handler. This must be called before using the // socket options utility. -func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) { +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits, getReceiveBufferLimits GetReceiveBufferLimits) { so.handler = handler so.stackHandler = stack so.getSendBufferLimits = getSendBufferLimits + so.getReceiveBufferLimits = getReceiveBufferLimits } func storeAtomicBool(addr *uint32, v bool) { @@ -632,3 +649,42 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { newSz := so.handler.OnSetSendBufferSize(v) atomic.StoreInt64(&so.sendBufferSize, newSz) } + +// GetReceiveBufferSize gets value for SO_RCVBUF option. +func (so *SocketOptions) GetReceiveBufferSize() int64 { + return atomic.LoadInt64(&so.receiveBufferSize) +} + +// SetReceiveBufferSize sets value for SO_RCVBUF option. +func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { + if !notify { + atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize) + return + } + + // Make sure the send buffer size is within the min and max + // allowed. 
+ v := receiveBufferSize + ss := so.getReceiveBufferLimits(so.stackHandler) + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the send buffer size with min and max values. + if v > max { + v = max + } + + // Multiply it by factor of 2. + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min + } + } else { + v = math.MaxInt32 + } + + oldSz := atomic.LoadInt64(&so.receiveBufferSize) + // Notify endpoint about change in buffer size. + newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) + atomic.StoreInt64(&so.receiveBufferSize, newSz) +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 49362333a..2bd6a67f5 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -45,6 +45,7 @@ go_library( "addressable_endpoint_state.go", "conntrack.go", "headertype_string.go", + "hook_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -66,6 +67,7 @@ go_library( "stack.go", "stack_global_state.go", "stack_options.go", + "tcp.go", "transport_demuxer.go", "tuple_list.go", ], @@ -115,6 +117,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -139,6 +142,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 3f083928f..41e964cf3 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -16,6 +16,7 @@ package stack import ( "encoding/binary" + "fmt" "sync" "time" @@ -29,7 +30,7 @@ import ( // The connection is created for a packet if it does not exist. Every // connection contains two tuples (original and reply). The tuples are // manipulated if there is a matching NAT rule. The packet is modified by -// looking at the tuples in the Prerouting and Output hooks. +// looking at the tuples in each hook. // // Currently, only TCP tracking is supported. @@ -46,12 +47,14 @@ const ( ) // Manipulation type for the connection. +// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and +// DNAT at the same time. type manipType int const ( manipNone manipType = iota - manipDstPrerouting - manipDstOutput + manipSource + manipDestination ) // tuple holds a connection's identifying and manipulating data in one @@ -108,6 +111,7 @@ type conn struct { reply tuple // manip indicates if the packet should be manipulated. It is immutable. + // TODO(gvisor.dev/issue/5696): Support updating manipulation type. manip manipType // tcbHook indicates if the packet is inbound or outbound to @@ -124,6 +128,18 @@ type conn struct { lastUsed time.Time `state:".(unixTime)"` } +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn +} + // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now time.Time) bool { const establishedTimeout = 5 * 24 * time.Hour @@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { }, nil } -// newConn creates new connection. 
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn -} - func (ct *ConnTrack) init() { ct.mu.Lock() defer ct.mu.Unlock() @@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1 return nil } - // Create a new connection and change the port as per the iptables - // rule. This tuple will be used to manipulate the packet in - // handlePacket. replyTID := tid.reply() replyTID.srcAddr = address replyTID.srcPort = port - var manip manipType - switch hook { - case Prerouting: - manip = manipDstPrerouting - case Output: - manip = manipDstOutput + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil } - conn := newConn(tid, replyTID, manip, hook) + conn = newConn(tid, replyTID, manipDestination, hook) + ct.insertConn(conn) + return conn +} + +func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Input && hook != Postrouting { + return nil + } + + replyTID := tid.reply() + replyTID.dstAddr = address + replyTID.dstPort = port + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil + } + conn = newConn(tid, replyTID, manipSource, hook) ct.insertConn(conn) return conn } @@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Now that we hold the locks, ensure the tuple hasn't been inserted by // another thread. + // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? alreadyInserted := false for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { if other.tupleID == conn.original.tupleID { @@ -343,86 +369,6 @@ func (ct *ConnTrack) insertConn(conn *conn) { } } -// handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. -func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For prerouting redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. - switch dir { - case dirOriginal: - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - case dirReply: - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated - // on inbound packets, so we don't recalculate them. However, we should - // support cases when they are validated, e.g. when we can't offload - // receive checksumming. - - // After modification, IPv4 packets need a valid checksum. 
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - -// handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For output redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. For prerouting redirection, we only reach this point - // when replying, so packet sources are modified. - if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - } else { - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - if gso != nil && gso.NeedsCsum { - tcpHeader.SetChecksum(xsum) - } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) - } - - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - // handlePacket will manipulate the port and address of the packet if the // connection exists. Returns whether, after the packet traverses the tables, // it should create a new entry in the table. @@ -431,7 +377,9 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou return false } - if hook != Prerouting && hook != Output { + switch hook { + case Prerouting, Input, Output, Postrouting: + default: return false } @@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } conn, dir := ct.connFor(pkt) - // Connection or Rule not found for the packet. + // Connection not found for the packet. if conn == nil { - return true + // If this is the last hook in the data path for this packet (Input if + // incoming, Postrouting if outgoing), indicate that a connection should be + // inserted by the end of this hook. + return hook == Input || hook == Postrouting } + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return false } + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be + // validated if checksum offloading is off. It may require IP defrag if the + // packets are fragmented. 
+ + switch hook { + case Prerouting, Output: + if conn.manip == manipDestination { + switch dir { + case dirOriginal: + tcpHeader.SetDestinationPort(conn.reply.srcPort) + netHeader.SetDestinationAddress(conn.reply.srcAddr) + case dirReply: + tcpHeader.SetSourcePort(conn.original.dstPort) + netHeader.SetSourceAddress(conn.original.dstAddr) + } + pkt.NatDone = true + } + case Input, Postrouting: + if conn.manip == manipSource { + switch dir { + case dirOriginal: + tcpHeader.SetSourcePort(conn.reply.dstPort) + netHeader.SetSourceAddress(conn.reply.dstAddr) + case dirReply: + tcpHeader.SetDestinationPort(conn.original.srcPort) + netHeader.SetDestinationAddress(conn.original.srcAddr) + } + pkt.NatDone = true + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + if !pkt.NatDone { + return false + } + switch hook { - case Prerouting: - handlePacketPrerouting(pkt, conn, dir) - case Output: - handlePacketOutput(pkt, conn, gso, r, dir) + case Prerouting, Input: + case Output, Postrouting: + // Calculate the TCP checksum and set it. + tcpHeader.SetChecksum(0) + length := uint16(len(tcpHeader) + pkt.Data().Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + if gso != nil && gso.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) } - pkt.NatDone = true // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle @@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ if conn == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip == manipNone { - // Unmanipulated connection. + } else if conn.manip != manipDestination { + // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } diff --git a/pkg/tcpip/stack/hook_string.go b/pkg/tcpip/stack/hook_string.go new file mode 100644 index 000000000..3dc8a7b02 --- /dev/null +++ b/pkg/tcpip/stack/hook_string.go @@ -0,0 +1,41 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type Hook ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. 
+ var x [1]struct{} + _ = x[Prerouting-0] + _ = x[Input-1] + _ = x[Forward-2] + _ = x[Output-3] + _ = x[Postrouting-4] + _ = x[NumHooks-5] +} + +const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks" + +var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47} + +func (i Hook) String() string { + if i >= Hook(len(_Hook_index)-1) { + return "Hook(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Hook_name[_Hook_index[i]:_Hook_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 52890f6eb..7ea87d325 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -175,9 +175,10 @@ func DefaultTables() *IPTables { }, }, priorities: [NumHooks][]TableID{ - Prerouting: {MangleID, NATID}, - Input: {NATID, FilterID}, - Output: {MangleID, NATID, FilterID}, + Prerouting: {MangleID, NATID}, + Input: {NATID, FilterID}, + Output: {MangleID, NATID, FilterID}, + Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ seed: generateRandUint32(), diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 0e8b90c9b..317efe754 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -182,3 +182,81 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs return RuleAccept, 0 } + +// SNATTarget modifies the source port/IP in the outgoing packets. +type SNATTarget struct { + Addr tcpip.Address + Port uint16 + + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. + if st.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + st.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + // Packet is already manipulated. + if pkt.NatDone { + return RuleAccept, 0 + } + + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { + return RuleDrop, 0 + } + + switch hook { + case Postrouting, Input: + case Prerouting, Output, Forward: + panic(fmt.Sprintf("%s not supported", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + udpHeader := header.UDP(pkt.TransportHeader().View()) + udpHeader.SetChecksum(0) + udpHeader.SetSourcePort(st.Port) + netHeader := pkt.Network() + netHeader.SetSourceAddress(st.Addr) + + // Only calculate the checksum if offloading isn't supported. + if r.RequiresTXTransportChecksum() { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) + } + + // After modification, IPv4 packets need a valid checksum. 
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } + pkt.NatDone = true + case header.TCPProtocolNumber: + if ct == nil { + return RuleAccept, 0 + } + + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. + if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { + ct.handlePacket(pkt, hook, gso, r) + } + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 14124ae66..b6cf24739 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -33,15 +33,19 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) +var ( + addr1 = testutil.MustParse6("a00::1") + addr2 = testutil.MustParse6("a00::2") + addr3 = testutil.MustParse6("a00::3") +) + const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") @@ -1390,7 +1394,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { // configured not to. func TestNoPrefixDiscovery(t *testing.T) { prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, } @@ -1590,7 +1594,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }() prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, } subnet := prefix.Subnet() @@ -5204,13 +5208,13 @@ func TestRouterSolicitation(t *testing.T) { } // Make sure the right remote link address is used. 
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) @@ -5362,7 +5366,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS()) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index bb2b2d705..1d39ee73d 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -26,14 +26,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 entryTestNICID tcpip.NICID = 1 - entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") @@ -44,6 +43,11 @@ const ( entryTestNetDefaultMTU = 65536 ) +var ( + entryTestAddr1 = testutil.MustParse6("a::1") + entryTestAddr2 = testutil.MustParse6("a::2") +) + // runImmediatelyScheduledJobs runs all jobs scheduled to run at the current // time. func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 8f288675d..c10304d5f 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - return NewPacketBuffer(PacketBufferOptions{ + newPk := NewPacketBuffer(PacketBufferOptions{ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), }) + // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to + // maintain this flag in the packet. Currently conntrack needs this flag to + // tell if a noop connection should be inserted at Input hook. Once conntrack + // redefines the manipulation field as mutable, we won't need the special noop + // connection. + if pk.NatDone { + newPk.NatDone = true + } + return newPk } // headerInfo stores metadata about a header in a packet. 
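For TCP, the SNATTarget added above defers the rewrite to conntrack: insertSNATConn records the connection with the NAT address/port stored as the reply tuple's destination, and handlePacket then manipulates the headers for the first packet and every packet after it. Packets in the original direction take their new source from conn.reply, and replies have their destination restored from conn.original. Below is a dependency-free sketch of that bookkeeping; the types and helpers are simplified stand-ins, not gVisor's actual structures.

package main

import "fmt"

// Simplified sketch of the source-NAT tuple bookkeeping used by the conntrack
// changes above. Names here are hypothetical.
type tupleID struct {
	srcAddr string
	srcPort uint16
	dstAddr string
	dstPort uint16
}

// reply returns the tuple a response packet would carry.
func (t tupleID) reply() tupleID {
	return tupleID{srcAddr: t.dstAddr, srcPort: t.dstPort, dstAddr: t.srcAddr, dstPort: t.srcPort}
}

func main() {
	// Original direction: 10.0.0.2:1234 -> 203.0.113.5:80, SNATed to 192.0.2.1:5000.
	original := tupleID{srcAddr: "10.0.0.2", srcPort: 1234, dstAddr: "203.0.113.5", dstPort: 80}

	// insertSNATConn-style setup: the reply tuple's destination becomes the NAT
	// address/port, since that is where the peer will address its responses.
	rep := original.reply()
	rep.dstAddr, rep.dstPort = "192.0.2.1", 5000

	// Outgoing packet (original direction, manipSource): the new source comes
	// from the reply tuple's destination.
	fmt.Printf("out: %s:%d -> %s:%d\n", rep.dstAddr, rep.dstPort, original.dstAddr, original.dstPort)

	// Incoming reply (reply direction): the destination is restored from the
	// original tuple's source.
	fmt.Printf("in:  %s:%d -> %s:%d (was %s:%d)\n",
		original.dstAddr, original.dstPort, original.srcAddr, original.srcPort, rep.dstAddr, rep.dstPort)
}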
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 39344808d..4ae6bed5a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp localAddr = addressEndpoint.AddressWithPrefix().Address } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) { addressEndpoint.DecRef() return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 931a97ddc..21cfbad71 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -35,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" ) @@ -56,306 +55,6 @@ type transportProtocolState struct { defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } -// TCPProbeFunc is the expected function type for a TCP probe function to be -// passed to stack.AddTCPProbe. -type TCPProbeFunc func(s TCPEndpointState) - -// TCPCubicState is used to hold a copy of the internal cubic state when the -// TCPProbeFunc is invoked. -type TCPCubicState struct { - WLastMax float64 - WMax float64 - T time.Time - TimeSinceLastCongestion time.Duration - C float64 - K float64 - Beta float64 - WC float64 - WEst float64 -} - -// TCPRACKState is used to hold a copy of the internal RACK state when the -// TCPProbeFunc is invoked. -type TCPRACKState struct { - XmitTime time.Time - EndSequence seqnum.Value - FACK seqnum.Value - RTT time.Duration - Reord bool - DSACKSeen bool - ReoWnd time.Duration - ReoWndIncr uint8 - ReoWndPersist int8 - RTTSeq seqnum.Value -} - -// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. -type TCPEndpointID struct { - // LocalPort is the local port associated with the endpoint. - LocalPort uint16 - - // LocalAddress is the local [network layer] address associated with - // the endpoint. - LocalAddress tcpip.Address - - // RemotePort is the remote port associated with the endpoint. - RemotePort uint16 - - // RemoteAddress it the remote [network layer] address associated with - // the endpoint. - RemoteAddress tcpip.Address -} - -// TCPFastRecoveryState holds a copy of the internal fast recovery state of a -// TCP endpoint. -type TCPFastRecoveryState struct { - // Active if true indicates the endpoint is in fast recovery. - Active bool - - // First is the first unacknowledged sequence number being recovered. - First seqnum.Value - - // Last is the 'recover' sequence number that indicates the point at - // which we should exit recovery barring any timeouts etc. - Last seqnum.Value - - // MaxCwnd is the maximum value we are permitted to grow the congestion - // window during recovery. This is set at the time we enter recovery. - MaxCwnd int - - // HighRxt is the highest sequence number which has been retransmitted - // during the current loss recovery phase. - // See: RFC 6675 Section 2 for details. - HighRxt seqnum.Value - - // RescueRxt is the highest sequence number which has been - // optimistically retransmitted to prevent stalling of the ACK clock - // when there is loss at the end of the window and no new data is - // available for transmission. - // See: RFC 6675 Section 2 for details. - RescueRxt seqnum.Value -} - -// TCPReceiverState holds a copy of the internal state of the receiver for -// a given TCP endpoint. 
-type TCPReceiverState struct { - // RcvNxt is the TCP variable RCV.NXT. - RcvNxt seqnum.Value - - // RcvAcc is the TCP variable RCV.ACC. - RcvAcc seqnum.Value - - // RcvWndScale is the window scaling to use for inbound segments. - RcvWndScale uint8 - - // PendingBufUsed is the number of bytes pending in the receive - // queue. - PendingBufUsed int -} - -// TCPSenderState holds a copy of the internal state of the sender for -// a given TCP Endpoint. -type TCPSenderState struct { - // LastSendTime is the time at which we sent the last segment. - LastSendTime time.Time - - // DupAckCount is the number of Duplicate ACK's received. - DupAckCount int - - // SndCwnd is the size of the sending congestion window in packets. - SndCwnd int - - // Ssthresh is the slow start threshold in packets. - Ssthresh int - - // SndCAAckCount is the number of packets consumed in congestion - // avoidance mode. - SndCAAckCount int - - // Outstanding is the number of packets in flight. - Outstanding int - - // SackedOut is the number of packets which have been selectively acked. - SackedOut int - - // SndWnd is the send window size in bytes. - SndWnd seqnum.Size - - // SndUna is the next unacknowledged sequence number. - SndUna seqnum.Value - - // SndNxt is the sequence number of the next segment to be sent. - SndNxt seqnum.Value - - // RTTMeasureSeqNum is the sequence number being used for the latest RTT - // measurement. - RTTMeasureSeqNum seqnum.Value - - // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. - RTTMeasureTime time.Time - - // Closed indicates that the caller has closed the endpoint for sending. - Closed bool - - // SRTT is the smoothed round-trip time as defined in section 2 of - // RFC 6298. - SRTT time.Duration - - // RTO is the retransmit timeout as defined in section of 2 of RFC 6298. - RTO time.Duration - - // RTTVar is the round-trip time variation as defined in section 2 of - // RFC 6298. - RTTVar time.Duration - - // SRTTInited if true indicates take a valid RTT measurement has been - // completed. - SRTTInited bool - - // MaxPayloadSize is the maximum size of the payload of a given segment. - // It is initialized on demand. - MaxPayloadSize int - - // SndWndScale is the number of bits to shift left when reading the send - // window size from a segment. - SndWndScale uint8 - - // MaxSentAck is the highest acknowledgement number sent till now. - MaxSentAck seqnum.Value - - // FastRecovery holds the fast recovery state for the endpoint. - FastRecovery TCPFastRecoveryState - - // Cubic holds the state related to CUBIC congestion control. - Cubic TCPCubicState - - // RACKState holds the state related to RACK loss detection algorithm. - RACKState TCPRACKState -} - -// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. -type TCPSACKInfo struct { - // Blocks is the list of SACK Blocks that identify the out of order segments - // held by a given TCP endpoint. - Blocks []header.SACKBlock - - // ReceivedBlocks are the SACK blocks received by this endpoint - // from the peer endpoint. - ReceivedBlocks []header.SACKBlock - - // MaxSACKED is the highest sequence number that has been SACKED - // by the peer. - MaxSACKED seqnum.Value -} - -// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning. -type RcvBufAutoTuneParams struct { - // MeasureTime is the time at which the current measurement - // was started. - MeasureTime time.Time - - // CopiedBytes is the number of bytes copied to user space since - // this measure began. 
- CopiedBytes int - - // PrevCopiedBytes is the number of bytes copied to userspace in - // the previous RTT period. - PrevCopiedBytes int - - // RcvBufSize is the auto tuned receive buffer size. - RcvBufSize int - - // RTT is the smoothed RTT as measured by observing the time between - // when a byte is first acknowledged and the receipt of data that is at - // least one window beyond the sequence number that was acknowledged. - RTT time.Duration - - // RTTVar is the "round-trip time variation" as defined in section 2 - // of RFC6298. - RTTVar time.Duration - - // RTTMeasureSeqNumber is the highest acceptable sequence number at the - // time this RTT measurement period began. - RTTMeasureSeqNumber seqnum.Value - - // RTTMeasureTime is the absolute time at which the current RTT - // measurement period began. - RTTMeasureTime time.Time - - // Disabled is true if an explicit receive buffer is set for the - // endpoint. - Disabled bool -} - -// TCPEndpointState is a copy of the internal state of a TCP endpoint. -type TCPEndpointState struct { - // ID is a copy of the TransportEndpointID for the endpoint. - ID TCPEndpointID - - // SegTime denotes the absolute time when this segment was received. - SegTime time.Time - - // RcvBufSize is the size of the receive socket buffer for the endpoint. - RcvBufSize int - - // RcvBufUsed is the amount of bytes actually held in the receive socket - // buffer for the endpoint. - RcvBufUsed int - - // RcvBufAutoTuneParams is used to hold state variables to compute - // the auto tuned receive buffer size. - RcvAutoParams RcvBufAutoTuneParams - - // RcvClosed if true, indicates the endpoint has been closed for reading. - RcvClosed bool - - // SendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1. - SendTSOk bool - - // RecentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - RecentTS uint32 - - // TSOffset is a randomized offset added to the value of the TSVal field - // in the timestamp option. - TSOffset uint32 - - // SACKPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - SACKPermitted bool - - // SACK holds TCP SACK related information for this endpoint. - SACK TCPSACKInfo - - // SndBufSize is the size of the socket send buffer. - SndBufSize int - - // SndBufUsed is the number of bytes held in the socket send buffer. - SndBufUsed int - - // SndClosed indicates that the endpoint has been closed for sends. - SndClosed bool - - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - - // PacketTooBigCount is used to notify the main protocol routine how - // many times a "packet too big" control packet is received. - PacketTooBigCount int - - // SndMTU is the smallest MTU seen in the control packets received. - SndMTU int - - // Receiver holds variables related to the TCP receiver for the endpoint. - Receiver TCPReceiverState - - // Sender holds state related to the TCP Sender for the endpoint. - Sender TCPSenderState -} - // ResumableEndpoint is an endpoint that needs to be resumed after restore. type ResumableEndpoint interface { // Resume resumes an endpoint after restore. 
This can be used to restart @@ -455,7 +154,7 @@ type Stack struct { // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. - receiveBufferSize ReceiveBufferSizeOption + receiveBufferSize tcpip.ReceiveBufferSizeOption // tcpInvalidRateLimit is the maximal rate for sending duplicate // acknowledgements in response to incoming TCP packets that are for an existing @@ -669,7 +368,7 @@ func New(opts Options) *Stack { Default: DefaultBufferSize, Max: DefaultMaxBufferSize, }, - receiveBufferSize: ReceiveBufferSizeOption{ + receiveBufferSize: tcpip.ReceiveBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, @@ -1344,7 +1043,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) + isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) @@ -1381,7 +1080,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index dfec4258a..33824afd0 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,6 +14,78 @@ package stack +import "time" + // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack + +// saveT is invoked by stateify. +func (t *TCPCubicState) saveT() unixTime { + return unixTime{t.T.Unix(), t.T.UnixNano()} +} + +// loadT is invoked by stateify. +func (t *TCPCubicState) loadT(unix unixTime) { + t.T = time.Unix(unix.second, unix.nano) +} + +// saveXmitTime is invoked by stateify. +func (t *TCPRACKState) saveXmitTime() unixTime { + return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (t *TCPRACKState) loadXmitTime(unix unixTime) { + t.XmitTime = time.Unix(unix.second, unix.nano) +} + +// saveLastSendTime is invoked by stateify. +func (t *TCPSenderState) saveLastSendTime() unixTime { + return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} +} + +// loadLastSendTime is invoked by stateify. +func (t *TCPSenderState) loadLastSendTime(unix unixTime) { + t.LastSendTime = time.Unix(unix.second, unix.nano) +} + +// saveRTTMeasureTime is invoked by stateify. +func (t *TCPSenderState) saveRTTMeasureTime() unixTime { + return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} +} + +// loadRTTMeasureTime is invoked by stateify. +func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { + t.RTTMeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveMeasureTime is invoked by stateify. 
+func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { + return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} +} + +// loadMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { + r.MeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveRTTMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { + return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} +} + +// loadRTTMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { + r.RTTMeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveSegTime is invoked by stateify. +func (t *TCPEndpointState) saveSegTime() unixTime { + return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} +} + +// loadSegTime is invoked by stateify. +func (t *TCPEndpointState) loadSegTime(unix unixTime) { + t.SegTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 3066f4ffd..80e8e0089 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -68,7 +68,7 @@ func (s *Stack) SetOption(option interface{}) tcpip.Error { s.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { @@ -107,7 +107,7 @@ func (s *Stack) Option(option interface{}) tcpip.Error { s.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.ReceiveBufferSizeOption: s.mu.RLock() *v = s.receiveBufferSize s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 2814b94b4..a0bd69d9a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -39,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) @@ -1645,10 +1646,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} - nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} - nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") + nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. 
s := stack.New(stack.Options{ @@ -2789,25 +2790,27 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") - ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") - toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 lifetimeSeconds = 9999 ) + var ( + linkLocalAddr1 = testutil.MustParse6("fe80::1") + linkLocalAddr2 = testutil.MustParse6("fe80::2") + linkLocalMulticastAddr = testutil.MustParse6("ff02::1") + uniqueLocalAddr1 = testutil.MustParse6("fc00::1") + uniqueLocalAddr2 = testutil.MustParse6("fd00::2") + globalAddr1 = testutil.MustParse6("a000::1") + globalAddr2 = testutil.MustParse6("a000::2") + globalAddr3 = testutil.MustParse6("a000::3") + ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1") + ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2") + toredoAddr1 = testutil.MustParse6("2001::1") + toredoAddr2 = testutil.MustParse6("2001::2") + ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1") + ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2") + ) + prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) @@ -3354,21 +3357,21 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - rs stack.ReceiveBufferSizeOption + rs tcpip.ReceiveBufferSizeOption err tcpip.Error }{ // Invalid configurations. 
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations - {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -3377,7 +3380,7 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { if err := s.SetOption(tc.rs); err != tc.err { t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if tc.err == nil { if err := s.Option(&rs); err != nil { t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) @@ -3448,7 +3451,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } ipv4Subnet := ipv4Addr.Subnet() ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4Gateway := testutil.MustParse4("192.168.1.1") ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ Address: "\xc0\xa8\x01\x3a", PrefixLen: 31, @@ -4352,13 +4355,15 @@ func TestWritePacketToRemote(t *testing.T) { func TestClearNeighborCacheOnNICDisable(t *testing.T) { const ( - nicID = 1 - - ipv4Addr = tcpip.Address("\x01\x02\x03\x04") - ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04") + nicID = 1 linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + var ( + ipv4Addr = testutil.MustParse4("1.2.3.4") + ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304") + ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, diff --git a/pkg/tcpip/stack/tcp.go 
b/pkg/tcpip/stack/tcp.go new file mode 100644 index 000000000..ddff6e2d6 --- /dev/null +++ b/pkg/tcpip/stack/tcp.go @@ -0,0 +1,451 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// TCPProbeFunc is the expected function type for a TCP probe function to be +// passed to stack.AddTCPProbe. +type TCPProbeFunc func(s TCPEndpointState) + +// TCPCubicState is used to hold a copy of the internal cubic state when the +// TCPProbeFunc is invoked. +// +// +stateify savable +type TCPCubicState struct { + // WLastMax is the previous wMax value. + WLastMax float64 + + // WMax is the value of the congestion window at the time of the last + // congestion event. + WMax float64 + + // T is the time when the current congestion avoidance was entered. + T time.Time `state:".(unixTime)"` + + // TimeSinceLastCongestion denotes the time since the current + // congestion avoidance was entered. + TimeSinceLastCongestion time.Duration + + // C is the cubic constant as specified in RFC8312, page 11. + C float64 + + // K is the time period (in seconds) that the above function takes to + // increase the current window size to WMax if there are no further + // congestion events and is calculated using the following equation: + // + // K = cubic_root(WMax*(1-beta_cubic)/C) (Eq. 2, page 5) + K float64 + + // Beta is the CUBIC multiplication decrease factor. That is, when a + // congestion event is detected, CUBIC reduces its cwnd to + // WC(0)=WMax*beta_cubic. + Beta float64 + + // WC is window computed by CUBIC at time TimeSinceLastCongestion. It's + // calculated using the formula: + // + // WC(TimeSinceLastCongestion) = C*(t-K)^3 + WMax (Eq. 1) + WC float64 + + // WEst is the window computed by CUBIC at time + // TimeSinceLastCongestion+RTT i.e WC(TimeSinceLastCongestion+RTT). + WEst float64 +} + +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +// +// +stateify savable +type TCPRACKState struct { + // XmitTime is the transmission timestamp of the most recent + // acknowledged segment. + XmitTime time.Time `state:".(unixTime)"` + + // EndSequence is the ending TCP sequence number of the most recent + // acknowledged segment. + EndSequence seqnum.Value + + // FACK is the highest selectively or cumulatively acknowledged + // sequence. + FACK seqnum.Value + + // RTT is the round trip time of the most recently delivered packet on + // the connection (either cumulatively acknowledged or selectively + // acknowledged) that was not marked invalid as a possible spurious + // retransmission. + RTT time.Duration + + // Reord is true iff reordering has been detected on this connection. + Reord bool + + // DSACKSeen is true iff the connection has seen a DSACK. + DSACKSeen bool + + // ReoWnd is the reordering window time used for recording packet + // transmission times. 
It is used to defer the moment at which RACK + // marks a packet lost. + ReoWnd time.Duration + + // ReoWndIncr is the multiplier applied to adjust reorder window. + ReoWndIncr uint8 + + // ReoWndPersist is the number of loss recoveries before resetting + // reorder window. + ReoWndPersist int8 + + // RTTSeq is the SND.NXT when RTT is updated. + RTTSeq seqnum.Value +} + +// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. +// +// +stateify savable +type TCPEndpointID struct { + // LocalPort is the local port associated with the endpoint. + LocalPort uint16 + + // LocalAddress is the local [network layer] address associated with + // the endpoint. + LocalAddress tcpip.Address + + // RemotePort is the remote port associated with the endpoint. + RemotePort uint16 + + // RemoteAddress it the remote [network layer] address associated with + // the endpoint. + RemoteAddress tcpip.Address +} + +// TCPFastRecoveryState holds a copy of the internal fast recovery state of a +// TCP endpoint. +// +// +stateify savable +type TCPFastRecoveryState struct { + // Active if true indicates the endpoint is in fast recovery. The + // following fields are only meaningful when Active is true. + Active bool + + // First is the first unacknowledged sequence number being recovered. + First seqnum.Value + + // Last is the 'recover' sequence number that indicates the point at + // which we should exit recovery barring any timeouts etc. + Last seqnum.Value + + // MaxCwnd is the maximum value we are permitted to grow the congestion + // window during recovery. This is set at the time we enter recovery. + // It exists to avoid attacks where the receiver intentionally sends + // duplicate acks to artificially inflate the sender's cwnd. + MaxCwnd int + + // HighRxt is the highest sequence number which has been retransmitted + // during the current loss recovery phase. See: RFC 6675 Section 2 for + // details. + HighRxt seqnum.Value + + // RescueRxt is the highest sequence number which has been + // optimistically retransmitted to prevent stalling of the ACK clock + // when there is loss at the end of the window and no new data is + // available for transmission. See: RFC 6675 Section 2 for details. + RescueRxt seqnum.Value +} + +// TCPReceiverState holds a copy of the internal state of the receiver for a +// given TCP endpoint. +// +// +stateify savable +type TCPReceiverState struct { + // RcvNxt is the TCP variable RCV.NXT. + RcvNxt seqnum.Value + + // RcvAcc is one beyond the last acceptable sequence number. That is, + // the "largest" sequence value that the receiver has announced to its + // peer that it's willing to accept. This may be different than RcvNxt + // + (last advertised receive window) if the receive window is reduced; + // in that case we have to reduce the window as we receive more data + // instead of shrinking it. + RcvAcc seqnum.Value + + // RcvWndScale is the window scaling to use for inbound segments. + RcvWndScale uint8 + + // PendingBufUsed is the number of bytes pending in the receive queue. + PendingBufUsed int +} + +// TCPRTTState holds a copy of information about the endpoint's round trip +// time. +// +// +stateify savable +type TCPRTTState struct { + // SRTT is the smoothed round trip time defined in section 2 of RFC + // 6298. + SRTT time.Duration + + // RTTVar is the round-trip time variation as defined in section 2 of + // RFC 6298. + RTTVar time.Duration + + // SRTTInited if true indicates that a valid RTT measurement has been + // completed. 
+ SRTTInited bool +} + +// TCPSenderState holds a copy of the internal state of the sender for a given +// TCP Endpoint. +// +// +stateify savable +type TCPSenderState struct { + // LastSendTime is the timestamp at which we sent the last segment. + LastSendTime time.Time `state:".(unixTime)"` + + // DupAckCount is the number of Duplicate ACKs received. It is used for + // fast retransmit. + DupAckCount int + + // SndCwnd is the size of the sending congestion window in packets. + SndCwnd int + + // Ssthresh is the threshold between slow start and congestion + // avoidance. + Ssthresh int + + // SndCAAckCount is the number of packets acknowledged during + // congestion avoidance. When enough packets have been ack'd (typically + // cwnd packets), the congestion window is incremented by one. + SndCAAckCount int + + // Outstanding is the number of packets that have been sent but not yet + // acknowledged. + Outstanding int + + // SackedOut is the number of packets which have been selectively + // acked. + SackedOut int + + // SndWnd is the send window size in bytes. + SndWnd seqnum.Size + + // SndUna is the next unacknowledged sequence number. + SndUna seqnum.Value + + // SndNxt is the sequence number of the next segment to be sent. + SndNxt seqnum.Value + + // RTTMeasureSeqNum is the sequence number being used for the latest + // RTT measurement. + RTTMeasureSeqNum seqnum.Value + + // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. + RTTMeasureTime time.Time `state:".(unixTime)"` + + // Closed indicates that the caller has closed the endpoint for + // sending. + Closed bool + + // RTO is the retransmit timeout as defined in section of 2 of RFC + // 6298. + RTO time.Duration + + // RTTState holds information about the endpoint's round trip time. + RTTState TCPRTTState + + // MaxPayloadSize is the maximum size of the payload of a given + // segment. It is initialized on demand. + MaxPayloadSize int + + // SndWndScale is the number of bits to shift left when reading the + // send window size from a segment. + SndWndScale uint8 + + // MaxSentAck is the highest acknowledgement number sent till now. + MaxSentAck seqnum.Value + + // FastRecovery holds the fast recovery state for the endpoint. + FastRecovery TCPFastRecoveryState + + // Cubic holds the state related to CUBIC congestion control. + Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState +} + +// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. +// +// +stateify savable +type TCPSACKInfo struct { + // Blocks is the list of SACK Blocks that identify the out of order + // segments held by a given TCP endpoint. + Blocks []header.SACKBlock + + // ReceivedBlocks are the SACK blocks received by this endpoint from + // the peer endpoint. + ReceivedBlocks []header.SACKBlock + + // MaxSACKED is the highest sequence number that has been SACKED by the + // peer. + MaxSACKED seqnum.Value +} + +// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning. +// +// +stateify savable +type RcvBufAutoTuneParams struct { + // MeasureTime is the time at which the current measurement was + // started. + MeasureTime time.Time `state:".(unixTime)"` + + // CopiedBytes is the number of bytes copied to user space since this + // measure began. + CopiedBytes int + + // PrevCopiedBytes is the number of bytes copied to userspace in the + // previous RTT period. + PrevCopiedBytes int + + // RcvBufSize is the auto tuned receive buffer size. 
+ RcvBufSize int + + // RTT is the smoothed RTT as measured by observing the time between + // when a byte is first acknowledged and the receipt of data that is at + // least one window beyond the sequence number that was acknowledged. + RTT time.Duration + + // RTTVar is the "round-trip time variation" as defined in section 2 of + // RFC6298. + RTTVar time.Duration + + // RTTMeasureSeqNumber is the highest acceptable sequence number at the + // time this RTT measurement period began. + RTTMeasureSeqNumber seqnum.Value + + // RTTMeasureTime is the absolute time at which the current RTT + // measurement period began. + RTTMeasureTime time.Time `state:".(unixTime)"` + + // Disabled is true if an explicit receive buffer is set for the + // endpoint. + Disabled bool +} + +// TCPRcvBufState contains information about the state of an endpoint's receive +// socket buffer. +// +// +stateify savable +type TCPRcvBufState struct { + // RcvBufUsed is the amount of bytes actually held in the receive + // socket buffer for the endpoint. + RcvBufUsed int + + // RcvBufAutoTuneParams is used to hold state variables to compute the + // auto tuned receive buffer size. + RcvAutoParams RcvBufAutoTuneParams + + // RcvClosed if true, indicates the endpoint has been closed for + // reading. + RcvClosed bool +} + +// TCPSndBufState contains information about the state of an endpoint's send +// socket buffer. +// +// +stateify savable +type TCPSndBufState struct { + // SndBufSize is the size of the socket send buffer. + SndBufSize int + + // SndBufUsed is the number of bytes held in the socket send buffer. + SndBufUsed int + + // SndClosed indicates that the endpoint has been closed for sends. + SndClosed bool + + // SndBufInQueue is the number of bytes in the send queue. + SndBufInQueue seqnum.Size + + // PacketTooBigCount is used to notify the main protocol routine how + // many times a "packet too big" control packet is received. + PacketTooBigCount int + + // SndMTU is the smallest MTU seen in the control packets received. + SndMTU int +} + +// TCPEndpointStateInner contains the members of TCPEndpointState used directly +// (that is, not within another containing struct) within the endpoint's +// internal implementation. +// +// +stateify savable +type TCPEndpointStateInner struct { + // TSOffset is a randomized offset added to the value of the TSVal + // field in the timestamp option. + TSOffset uint32 + + // SACKPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + SACKPermitted bool + + // SendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1. + SendTSOk bool + + // RecentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field + // is updated if required when a new segment is received by this + // endpoint. + RecentTS uint32 +} + +// TCPEndpointState is a copy of the internal state of a TCP endpoint. +// +// +stateify savable +type TCPEndpointState struct { + // TCPEndpointStateInner contains the members of TCPEndpointState used + // by the endpoint's internal implementation. + TCPEndpointStateInner + + // ID is a copy of the TransportEndpointID for the endpoint. + ID TCPEndpointID + + // SegTime denotes the absolute time when this segment was received. 
+ SegTime time.Time `state:".(unixTime)"` + + // RcvBufState contains information about the state of the endpoint's + // receive socket buffer. + RcvBufState TCPRcvBufState + + // SndBufState contains information about the state of the endpoint's + // send socket buffer. + SndBufState TCPSndBufState + + // SACK holds TCP SACK related information for this endpoint. + SACK TCPSACKInfo + + // Receiver holds variables related to the TCP receiver for the + // endpoint. + Receiver TCPReceiverState + + // Sender holds state related to the TCP Sender for the endpoint. + Sender TCPSenderState +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e188efccb..80ad1a9d4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { return eps } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { +// handlePacket is called by the stack when new packets arrive to this transport +// endpoint. It returns false if the packet could not be matched to any +// transport endpoint, true otherwise. +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return false } } @@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return true } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() - return + return true } transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. + return true } // handleError delivers an error to the transport endpoint identified by id. 
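The demuxer change above makes handlePacket report whether any transport endpoint claimed the packet, and the deliverPacket hunk that follows simply forwards that result instead of unconditionally returning true. A rough, self-contained sketch of that calling pattern, using toy types rather than gVisor's (everything below is illustrative):

package main

import "fmt"

// endpointID and demuxer are toy stand-ins for gVisor's demultiplexer state;
// the point is only the boolean delivery contract introduced above.
type endpointID struct {
	localPort uint16
}

type demuxer struct {
	endpoints map[endpointID]func(payload []byte)
}

// deliver returns false when no endpoint claims the packet, mirroring the new
// handlePacket/deliverPacket behaviour.
func (d *demuxer) deliver(id endpointID, payload []byte) bool {
	h, ok := d.endpoints[id]
	if !ok {
		return false
	}
	h(payload)
	return true
}

func main() {
	d := demuxer{endpoints: map[endpointID]func(payload []byte){
		{localPort: 80}: func(p []byte) { fmt.Printf("delivered %d bytes\n", len(p)) },
	}}
	// An unmatched packet is now visible to the caller, which can count the
	// drop or answer with an error instead of assuming delivery succeeded.
	if !d.deliver(endpointID{localPort: 81}, []byte("hello")) {
		fmt.Println("no endpoint matched; caller decides what to do")
	}
}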
@@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, } return false } - ep.handlePacket(id, pkt) - return true + return ep.handlePacket(id, pkt) } // deliverRawPacket attempts to deliver the given packet and returns whether it diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 054cced0c..0adedd7c0 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -70,7 +70,7 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} - ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) return ep } @@ -233,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress(), route: route, } - ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) f.acceptQueue = append(f.acceptQueue, ep) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 87ea09a5e..0ba71b62e 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -691,10 +691,6 @@ const ( // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption - // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to - // specify the receive buffer size option. - ReceiveBufferSizeOption - // SendQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the output buffer should be returned. SendQueueSizeOption @@ -786,6 +782,13 @@ func (*TCPRecovery) isGettableTransportProtocolOption() {} func (*TCPRecovery) isSettableTransportProtocolOption() {} +// TCPAlwaysUseSynCookies indicates unconditional usage of syncookies. +type TCPAlwaysUseSynCookies bool + +func (*TCPAlwaysUseSynCookies) isGettableTransportProtocolOption() {} + +func (*TCPAlwaysUseSynCookies) isSettableTransportProtocolOption() {} + const ( // TCPRACKLossDetection indicates RACK is used for loss detection and // recovery. @@ -1020,19 +1023,6 @@ func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {} func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {} -// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify -// the number of endpoints that can be in SYN-RCVD state before the stack -// switches to using SYN cookies. -type TCPSynRcvdCountThresholdOption uint64 - -func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {} - -func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {} - -func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {} - -func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {} - // TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide // default for number of times SYN is retransmitted before aborting a connect. type TCPSynRetriesOption uint8 @@ -1150,6 +1140,19 @@ type SendBufferSizeOption struct { Max int } +// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max receive buffer sizes. 
+type ReceiveBufferSizeOption struct { + // Min is the minimum size for the receive buffer. + Min int + + // Default is the default size for the receive buffer. + Default int + + // Max is the maximum size for the receive buffer. + Max int +} + // GetSendBufferLimits is used to get the send buffer size limits. type GetSendBufferLimits func(StackHandler) SendBufferSizeOption @@ -1162,6 +1165,18 @@ func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption { return ss } +// GetReceiveBufferLimits is used to get the receive buffer size limits. +type GetReceiveBufferLimits func(StackHandler) ReceiveBufferSizeOption + +// GetStackReceiveBufferLimits is used to get the default, min and max receive buffer sizes. +func GetStackReceiveBufferLimits(so StackHandler) ReceiveBufferSizeOption { + var ss ReceiveBufferSizeOption + if err := so.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + return ss +} + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. @@ -1218,7 +1233,7 @@ func (s *StatCounter) Decrement() { } // Value returns the current value of the counter. -func (s *StatCounter) Value() uint64 { +func (s *StatCounter) Value(name ...string) uint64 { return atomic.LoadUint64(&s.count) } @@ -1562,6 +1577,10 @@ type IPStats struct { // chain. IPTablesOutputDropped *StatCounter + // IPTablesPostroutingDropped is the number of IP packets dropped in the + // Postrouting chain. + IPTablesPostroutingDropped *StatCounter + // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out // of IPStats. // OptionTimestampReceived is the number of Timestamp options seen. @@ -1734,6 +1753,10 @@ type TCPStats struct { // ChecksumErrors is the number of segments dropped due to bad checksums. ChecksumErrors *StatCounter + + // FailedPortReservations is the number of times TCP failed to reserve + // a port. + FailedPortReservations *StatCounter } // UDPStats collects UDP-specific stats.
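The tcpip.go hunks above move receive buffer limits to a stack-level tcpip.ReceiveBufferSizeOption, mirroring the existing SendBufferSizeOption, with GetStackReceiveBufferLimits as the accessor endpoints consult through tcpip.SocketOptions. A minimal sketch of setting and reading the new option on a stack, matching the pattern the stack test above exercises (the buffer sizes and the bare stack.Options are illustrative):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	s := stack.New(stack.Options{})

	// Stack-wide receive buffer limits; Min <= Default <= Max must hold,
	// otherwise SetOption fails with *tcpip.ErrInvalidOptionValue.
	opt := tcpip.ReceiveBufferSizeOption{Min: 4 << 10, Default: 32 << 10, Max: 4 << 20}
	if err := s.SetOption(opt); err != nil {
		panic(fmt.Sprintf("SetOption(%#v) = %s", opt, err))
	}

	// GetStackReceiveBufferLimits performs this read on behalf of endpoints.
	var rs tcpip.ReceiveBufferSizeOption
	if err := s.Option(&rs); err != nil {
		panic(fmt.Sprintf("Option(%#v) = %s", rs, err))
	}
	fmt.Printf("default receive buffer: %d bytes\n", rs.Default)
}

Endpoints created afterwards pick up the stack default through ops.SetReceiveBufferSize, which is exactly what the icmp, packet and raw endpoint constructors later in this diff do.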
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 3cc8c36f1..d4f7bb5ff 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -9,11 +9,14 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -78,6 +81,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", @@ -101,6 +105,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -123,6 +128,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index d10ae05c2..dbd279c94 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -21,11 +21,14 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -312,3 +315,194 @@ func TestForwarding(t *testing.T) { }) } } + +func TestMulticastForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + ttl = 64 + ) + + var ( + ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") + ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + + ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") + ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") + ) + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo))) + } + + v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest))) + } + + tests := []struct { + name string + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + expectForward bool + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 link-local multicast destination", + srcAddr: utils.RemoteIPv4Addr, + 
dstAddr: ipv4LinkLocalMulticastAddr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 link-local source", + srcAddr: ipv4LinkLocalUnicastAddr, + dstAddr: utils.RemoteIPv4Addr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 link-local destination", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4LinkLocalUnicastAddr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 non-link-local unicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 non-link-local multicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4GlobalMulticastAddr, + rx: rxICMPv4EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + }, + }, + + { + name: "IPv6 link-local multicast destination", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6LinkLocalMulticastAddr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 link-local source", + srcAddr: ipv6LinkLocalUnicastAddr, + dstAddr: utils.RemoteIPv6Addr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 link-local destination", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6LinkLocalUnicastAddr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 non-link-local unicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 non-link-local multicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6GlobalMulticastAddr, + rx: rxICMPv6EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + } + + if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID2, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 
nicID2, + }, + }) + + test.rx(e1, test.srcAddr, test.dstAddr) + + p, ok := e2.Read() + if ok != test.expectForward { + t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward) + } + + if test.expectForward { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 2c538a43e..b04169751 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -510,25 +511,25 @@ func TestExternalLoopbackTraffic(t *testing.T) { nicID1 = 1 nicID2 = 2 - ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") - numPackets = 1 + ttl = 64 ) + ipv4Loopback := testutil.MustParse4("127.0.0.1") loopbackSourcedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address) + utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl) } loopbackSourcedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address) + utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl) } loopbackDestinedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback) + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl) } loopbackDestinedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback) + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl) } invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index c6a9c2393..2d0a6e6a7 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -43,12 +44,15 @@ const ( // to a multicast or broadcast address uses a unicast source address for the // reply. func TestPingMulticastBroadcast(t *testing.T) { - const nicID = 1 + const ( + nicID = 1 + ttl = 64 + ) tests := []struct { name string protoNum tcpip.NetworkProtocolNumber - rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address) + rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8) srcAddr tcpip.Address dstAddr tcpip.Address expectedSrc tcpip.Address @@ -136,7 +140,7 @@ func TestPingMulticastBroadcast(t *testing.T) { }, }) - test.rxICMP(e, test.srcAddr, test.dstAddr) + test.rxICMP(e, test.srcAddr, test.dstAddr, ttl) pkt, ok := e.Read() if !ok { t.Fatal("expected ICMP response") @@ -435,10 +439,10 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { // interested endpoints. 
func TestReuseAddrAndBroadcast(t *testing.T) { const ( - nicID = 1 - localPort = 9000 - loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") + nicID = 1 + localPort = 9000 ) + loopbackBroadcast := testutil.MustParse4("127.255.255.255") tests := []struct { name string diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 78244f4eb..ac3c703d4 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -40,13 +41,13 @@ import ( // This tests that a local route is created and packets do not leave the stack. func TestLocalPing(t *testing.T) { const ( - nicID = 1 - ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + nicID = 1 // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo // request/reply packets. icmpDataOffset = 8 ) + ipv4Loopback := testutil.MustParse4("127.0.0.1") channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index d1c9f3a94..8fd9be32b 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -48,10 +48,6 @@ const ( LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") ) -const ( - ttl = 255 -) - // Common IP addresses used by tests. var ( Ipv4Addr = tcpip.AddressWithPrefix{ @@ -322,7 +318,7 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. // RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on // the provided endpoint. -func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -347,7 +343,7 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { // RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on // the provided endpoint. -func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD new file mode 100644 index 000000000..472545a5d --- /dev/null +++ b/pkg/tcpip/testutil/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + testonly = True, + srcs = ["testutil.go"], + visibility = ["//visibility:public"], + deps = ["//pkg/tcpip"], +) + +go_test( + name = "testutil_test", + srcs = ["testutil_test.go"], + library = ":testutil", + deps = ["//pkg/tcpip"], +) diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go new file mode 100644 index 000000000..1aaed590f --- /dev/null +++ b/pkg/tcpip/testutil/testutil.go @@ -0,0 +1,43 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil provides helper functions for netstack unit tests. +package testutil + +import ( + "fmt" + "net" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address. +// Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes. +func MustParse4(addr string) tcpip.Address { + ip := net.ParseIP(addr).To4() + if ip == nil { + panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr)) + } + return tcpip.Address(ip) +} + +// MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing +// an IPv4 address will yield an IPv4-mapped IPv6 address. +func MustParse6(addr string) tcpip.Address { + ip := net.ParseIP(addr).To16() + if ip == nil { + panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr)) + } + return tcpip.Address(ip) +} diff --git a/pkg/tcpip/testutil/testutil_test.go b/pkg/tcpip/testutil/testutil_test.go new file mode 100644 index 000000000..6aad9585d --- /dev/null +++ b/pkg/tcpip/testutil/testutil_test.go @@ -0,0 +1,103 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// Who tests the testutils? + +func TestMustParse4(t *testing.T) { + tcs := []struct { + str string + addr tcpip.Address + shouldPanic bool + }{ + { + str: "127.0.0.1", + addr: "\x7f\x00\x00\x01", + }, { + str: "", + shouldPanic: true, + }, { + str: "fe80::1", + shouldPanic: true, + }, { + // In an ideal world this panics too, but net.IP + // doesn't distinguish between IPv4 and IPv4-mapped + // addresses. + str: "::ffff:0.0.0.1", + addr: "\x00\x00\x00\x01", + }, + } + + for _, tc := range tcs { + t.Run(tc.str, func(t *testing.T) { + if tc.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("panic expected, but did not occur") + } + }() + } + if got := MustParse4(tc.str); got != tc.addr { + t.Errorf("got MustParse4(%s) = %s, want = %s", tc.str, got, tc.addr) + } + }) + } +} + +func TestMustParse6(t *testing.T) { + tcs := []struct { + str string + addr tcpip.Address + shouldPanic bool + }{ + { + // In an ideal world this panics too, but net.IP + // doesn't distinguish between IPv4 and IPv4-mapped + // addresses. 
+ str: "127.0.0.1", + addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x7f\x00\x00\x01", + }, { + str: "", + shouldPanic: true, + }, { + str: "fe80::1", + addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + }, { + str: "::ffff:0.0.0.1", + addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01", + }, + } + + for _, tc := range tcs { + t.Run(tc.str, func(t *testing.T) { + if tc.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("panic expected, but did not occur") + } + }() + } + if got := MustParse6(tc.str); got != tc.addr { + t.Errorf("got MustParse6(%s) = %s, want = %s", tc.str, got, tc.addr) + } + }) + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 50991c3c0..33ed78f54 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -63,12 +63,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList icmpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList icmpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` @@ -84,6 +83,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { @@ -93,19 +96,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - state: stateInitial, - uniqueID: s.UniqueID(), + waiterQueue: waiterQueue, + state: stateInitial, + uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. 
var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } + var rs tcpip.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) + } return ep, nil } @@ -371,12 +378,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.rcvMu.Lock() v := int(e.ttl) @@ -774,7 +775,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -843,3 +845,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index a3c6db5a8..28a56a2d5 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -36,40 +36,21 @@ func (p *icmpPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) if e.state != stateBound && e.state != stateConnected { return diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 52ed9560c..496eca581 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -72,11 +72,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. 
- rcvMu sync.Mutex `state:"nosave"` - rcvList packetList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -91,6 +90,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a new packet endpoint. @@ -100,12 +103,12 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, }, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -113,9 +116,9 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - ep.rcvBufSizeMax = rs.Default + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { @@ -316,28 +319,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := ep.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - ep.rcvMu.Lock() - ep.rcvBufSizeMax = v - ep.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } func (ep *endpoint) LastError() tcpip.Error { @@ -374,12 +356,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - v := ep.rcvBufSizeMax - ep.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -397,7 +373,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, return } - if ep.rcvBufSize >= ep.rcvBufSizeMax { + rcvBufSize := ep.ops.GetReceiveBufferSize() + if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -513,3 +490,18 @@ func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } + +// freeze prevents any more packets from being delivered to the endpoint. 
+func (ep *endpoint) freeze() { + ep.mu.Lock() + ep.frozen = true + ep.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (ep *endpoint) thaw() { + ep.mu.Lock() + ep.frozen = false + ep.mu.Unlock() +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index ece662c0d..5bd860d20 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -38,33 +38,14 @@ func (p *packet) loadData(data buffer.VectorisedView) { // beforeSave is invoked by stateify. func (ep *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - ep.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - ep.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max + ep.freeze() } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { + ep.thaw() ep.stack = stack.StackFromEnv - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e27a249cd..10453a42a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,7 +26,6 @@ package raw import ( - "fmt" "io" "gvisor.dev/gvisor/pkg/sync" @@ -69,11 +68,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList rawPacketList - rcvBufSize int - rcvBufSizeMax int `state:".(int)"` - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList rawPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -89,6 +87,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a raw endpoint for the given protocols. @@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - associated: associated, + waiterQueue: waiterQueue, + associated: associated, } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. 
var ss tcpip.SendBufferSizeOption @@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } // Unassociated endpoints are write-only and users call Write() with IP // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - e.rcvBufSizeMax = 0 + e.ops.SetReceiveBufferSize(0, false) e.waiterQueue = nil return e, nil } @@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - e.rcvMu.Lock() - e.rcvBufSizeMax = v - e.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() @@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 263ec5146..5d6f2709c 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. 
- e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // If the endpoint is connected, re-connect. if e.connected { diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index a69d6624d..48417f192 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -34,14 +34,12 @@ go_library( "connect.go", "connect_unsafe.go", "cubic.go", - "cubic_state.go", "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", "protocol.go", "rack.go", - "rack_state.go", "rcv.go", "rcv_state.go", "reno.go", @@ -107,6 +105,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/tcp/testing/context", "//pkg/test/testutil", "//pkg/waiter", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 025b134e2..d4bd4e80e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -23,7 +23,6 @@ import ( "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -51,11 +50,6 @@ const ( // timestamp and the current timestamp. If the difference is greater // than maxTSDiff, the cookie is expired. maxTSDiff = 2 - - // SynRcvdCountThreshold is the default global maximum number of - // connections that are allowed to be in SYN-RCVD state before TCP - // starts using SYN cookies to accept connections. - SynRcvdCountThreshold uint64 = 1000 ) var ( @@ -80,9 +74,6 @@ func encodeMSS(mss uint16) uint32 { type listenContext struct { stack *stack.Stack - // synRcvdCount is a reference to the stack level synRcvdCount. - synRcvdCount *synRcvdCounter - // rcvWnd is the receive window that is sent by this listening context // in the initial SYN-ACK. rcvWnd seqnum.Size @@ -138,14 +129,12 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } - p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol) - if !ok { - panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk)) - } - l.synRcvdCount = p.SynRcvdCounter() - rand.Read(l.nonce[0][:]) - rand.Read(l.nonce[1][:]) + for i := range l.nonce { + if _, err := io.ReadFull(stk.SecureRNG(), l.nonce[i][:]); err != nil { + panic(err) + } + } return l } @@ -163,14 +152,17 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // Feed everything to the hasher. l.hasherMu.Lock() l.hasher.Reset() + + // Per hash.Hash.Writer: + // + // It never returns an error. 
l.hasher.Write(payload[:]) l.hasher.Write(l.nonce[nonceIndex][:]) - io.WriteString(l.hasher, string(id.LocalAddress)) - io.WriteString(l.hasher, string(id.RemoteAddress)) + l.hasher.Write([]byte(id.LocalAddress)) + l.hasher.Write([]byte(id.RemoteAddress)) // Finalize the calculation of the hash and return the first 4 bytes. - h := make([]byte, 0, sha1.Size) - h = l.hasher.Sum(h) + h := l.hasher.Sum(nil) l.hasherMu.Unlock() return binary.BigEndian.Uint32(h[:]) @@ -199,9 +191,17 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } +func (l *listenContext) useSynCookies() bool { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) +} + // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -215,11 +215,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n := newEndpoint(l.stack, netProto, queue) n.ops.SetV6Only(l.v6Only) - n.ID = s.id + n.TransportEndpointInfo.ID = s.id n.boundNICID = s.nicID n.route = route n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} - n.rcvBufSize = int(l.rcvWnd) + n.ops.SetReceiveBufferSize(int64(l.rcvWnd), false /* notify */) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -231,7 +231,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // Bootstrap the auto tuning algorithm. Starting at zero will result in // a large step function on the first window adjustment causing the // window to grow to a really large value. - n.rcvAutoParams.prevCopied = n.initialReceiveWindow() + n.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = n.initialReceiveWindow() return n, nil } @@ -248,7 +248,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err } @@ -290,7 +290,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q } // Register new endpoint so that packets are routed to it. - if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + if err := ep.stack.RegisterTransportEndpoint( + ep.effectiveNetProtos, + ProtocolNumber, + ep.TransportEndpointInfo.ID, + ep, + ep.boundPortFlags, + ep.boundBindToDevice, + ); err != nil { ep.mu.Unlock() ep.Close() @@ -307,6 +314,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Initialize and start the handshake. 
h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) + h.listenEP = l.listenEP h.start() return h, nil } @@ -334,14 +342,14 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.ID] = n + l.pendingEndpoints[n.TransportEndpointInfo.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.ID) + delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -382,39 +390,46 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window // scaling. - e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() + e.rcv.RcvWndScale = e.h.effectiveRcvWndScale() // Clean up handshake state stored in the endpoint so that it can be GCed. e.h = nil } // deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// endpoint has transitioned out of the listen state (acceptedChan is nil), -// the new endpoint is closed instead. +// listener has transitioned out of the listen state (accepted is the zero +// value), the new endpoint is reset instead. func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Lock() e.pendingAccepted.Add(1) e.mu.Unlock() defer e.pendingAccepted.Done() - e.acceptMu.Lock() - for { - if e.acceptedChan == nil { - e.acceptMu.Unlock() - n.notifyProtocolGoroutine(notifyReset) - return - } - select { - case e.acceptedChan <- n: + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + if e.accepted == (accepted{}) { + return false + } + if e.accepted.endpoints.Len() == e.accepted.cap { + e.acceptCond.Wait() + continue + } + + e.accepted.endpoints.PushBack(n) if !withSynCookie { atomic.AddInt32(&e.synRcvdCount, -1) } - e.acceptMu.Unlock() - e.waiterQueue.Notify(waiter.ReadableEvents) - return - default: - e.acceptCond.Wait() + return true } + }() + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + n.notifyProtocolGoroutine(notifyReset) } } @@ -436,17 +451,21 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // * propagateInheritableOptionsLocked has been called. // * e.mu is held. 
func (e *endpoint) reserveTupleLocked() bool { - dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + dest := tcpip.FullAddress{ + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, + } portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: dest, } if !e.stack.ReserveTuple(portRes) { + e.stack.Stats().TCP.FailedPortReservations.Increment() return false } @@ -485,7 +504,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } go func() { - defer ctx.synRcvdCount.dec() if err := h.complete(); err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() @@ -497,24 +515,29 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.deliverAccepted(h.ep, false /*withSynCookie*/) - }() // S/R-SAFE: synRcvdCount is the barrier. + }() return nil } -func (e *endpoint) incSynRcvdCount() bool { +func (e *endpoint) synRcvdBacklogFull() bool { e.acceptMu.Lock() - canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan) + acceptedCap := e.accepted.cap e.acceptMu.Unlock() - if canInc { - atomic.AddInt32(&e.synRcvdCount, 1) - } - return canInc + // The capacity of the accepted queue would always be one greater than the + // listen backlog. But, the SYNRCVD connections count is always checked + // against the listen backlog value for Linux parity reason. + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 + // + // We maintain an equality check here as the synRcvdCount is incremented + // and compared only from a single listener context and the capacity of + // the accepted queue can only increase by a new listen call. + return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan) + full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap e.acceptMu.Unlock() return full } @@ -524,9 +547,9 @@ func (e *endpoint) acceptQueueIsFull() bool { // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { - e.rcvListMu.Lock() - rcvClosed := e.rcvClosed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := e.rcvQueueInfo.RcvClosed + e.rcvQueueInfo.rcvQueueMu.Unlock() if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { // If the endpoint is shutdown, reply with reset. // @@ -538,69 +561,55 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err switch { case s.flags == header.TCPFlagSyn: - opts := parseSynSegmentOptions(s) - if ctx.synRcvdCount.inc() { - // Only handle the syn if the following conditions hold - // - accept queue is not full. - // - number of connections in synRcvd state is less than the - // backlog. 
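The reworked listener path in this file decides between three outcomes: handle the SYN directly while the SYN-RCVD count stays below the backlog, fall back to SYN cookies once that count reaches the backlog (or when the stack option above forces them), and, as the following hunk lines show, drop the SYN outright when the accept queue itself is full. A simplified model of that accounting, assuming the accept queue capacity of backlog+1 described above; the types are purely illustrative:

package main

import "fmt"

// listener models the two checks made on an incoming SYN.
type listener struct {
	alwaysUseSynCookies bool
	synRcvdCount        int // connections currently in SYN-RCVD
	acceptedLen         int // connections waiting in Accept()
	acceptedCap         int // backlog + 1, as in the accepted queue above
}

// synRcvdBacklogFull mirrors the "== cap-1" comparison: the count is checked
// against the listen backlog, not against the queue capacity itself.
func (l *listener) synRcvdBacklogFull() bool { return l.synRcvdCount == l.acceptedCap-1 }

func (l *listener) acceptQueueIsFull() bool { return l.acceptedLen == l.acceptedCap }

func (l *listener) useSynCookies() bool { return l.alwaysUseSynCookies || l.synRcvdBacklogFull() }

func main() {
	l := &listener{acceptedCap: 5, synRcvdCount: 4}
	fmt.Println(l.useSynCookies())     // true: SYN-RCVD count reached the backlog
	fmt.Println(l.acceptQueueIsFull()) // false: still room to deliver accepted endpoints
}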
- if !e.acceptQueueIsFull() && e.incSynRcvdCount() { - s.incRef() - _ = e.handleSynSegment(ctx, s, &opts) - return nil - } - ctx.synRcvdCount.dec() + if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return nil - } else { - // If cookies are in use but the endpoint accept queue - // is full then drop the syn. - if e.acceptQueueIsFull() { - e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() - e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() - e.stack.Stats().DroppedPackets.Increment() - return nil - } - cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + } - route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) - if err != nil { - return err - } - defer route.Release() + opts := parseSynSegmentOptions(s) + if !ctx.useSynCookies() { + s.incRef() + atomic.AddInt32(&e.synRcvdCount, 1) + return e.handleSynSegment(ctx, s, &opts) + } + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() - // Send SYN without window scaling because we currently - // don't encode this information in the cookie. - // - // Enable Timestamp option if the original syn did have - // the timestamp option specified. - // - // Use the user supplied MSS on the listening socket for - // new connections, if available. - synOpts := header.TCPSynOptions{ - WS: -1, - TS: opts.TS, - TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), - TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, route), - } - fields := tcpFields{ - id: s.id, - ttl: e.ttl, - tos: e.sendTOS, - flags: header.TCPFlagSyn | header.TCPFlagAck, - seq: cookie, - ack: s.sequenceNumber + 1, - rcvWnd: ctx.rcvWnd, - } - if err := e.sendSynTCP(route, fields, synOpts); err != nil { - return err - } - e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() - return nil + // Send SYN without window scaling because we currently + // don't encode this information in the cookie. + // + // Enable Timestamp option if the original syn did have + // the timestamp option specified. + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. + synOpts := header.TCPSynOptions{ + WS: -1, + TS: opts.TS, + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), + TSEcr: opts.TSVal, + MSS: calculateAdvertisedMSS(e.userMSS, route), + } + cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + fields := tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagSyn | header.TCPFlagAck, + seq: cookie, + ack: s.sequenceNumber + 1, + rcvWnd: ctx.rcvWnd, + } + if err := e.sendSynTCP(route, fields, synOpts); err != nil { + return err } + e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { @@ -615,25 +624,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } - if !ctx.synRcvdCount.synCookiesInUse() { - // When not using SYN cookies, as per RFC 793, section 3.9, page 64: - // Any acknowledgment is bad if it arrives on a connection still in - // the LISTEN state. An acceptable reset segment should be formed - // for any arriving ACK-bearing segment. 
The RST should be - // formatted as follows: - // - // <SEQ=SEG.ACK><CTL=RST> - // - // Send a reset as this is an ACK for which there is no - // half open connections and we are not using cookies - // yet. - // - // The only time we should reach here when a connection - // was opened and closed really quickly and a delayed - // ACK was received from the sender. - return replyWithReset(e.stack, s, e.sendTOS, e.ttl) - } - iss := s.ackNumber - 1 irs := s.sequenceNumber - 1 @@ -651,7 +641,23 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() - return nil + + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // + // Send a reset as this is an ACK for which there is no + // half open connections and we are not using cookies + // yet. + // + // The only time we should reach here when a connection + // was opened and closed really quickly and a delayed + // ACK was received from the sender. + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. @@ -672,7 +678,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { return err } @@ -693,7 +699,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint( + n.effectiveNetProtos, + ProtocolNumber, + n.TransportEndpointInfo.ID, + n, + n.boundPortFlags, + n.boundBindToDevice, + ); err != nil { n.mu.Unlock() n.Close() @@ -708,7 +721,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was // sent above. - n.tsOffset = 0 + n.TSOffset = 0 // Switch state to connected. n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a9e978cf6..7bc6b08f0 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -65,11 +65,12 @@ const ( // NOTE: handshake.ep.mu is held during handshake processing. It is released if // we are going to block and reacquired when we start processing an event. type handshake struct { - ep *endpoint - state handshakeState - active bool - flags header.TCPFlags - ackNum seqnum.Value + ep *endpoint + listenEP *endpoint + state handshakeState + active bool + flags header.TCPFlags + ackNum seqnum.Value // iss is the initial send sequence number, as defined in RFC 793. 
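As context for the iss field documented here (declared on the next hunk line), the value comes from generateSecureISN further down in this diff, in the RFC 6528 style: a keyed hash over the connection 4-tuple plus a clock component. A rough standalone illustration, where the hash choice and the clock granularity are assumptions rather than the actual implementation:

package main

import (
	"crypto/sha256"
	"encoding/binary"
	"fmt"
	"time"
)

// secureISN derives an initial sequence number from the 4-tuple, a secret
// seed, and a coarse clock, so ISNs are hard to predict off-path while still
// advancing over time for a given connection tuple.
func secureISN(laddr, raddr []byte, lport, rport uint16, seed uint32) uint32 {
	h := sha256.New()
	var buf [12]byte
	binary.BigEndian.PutUint16(buf[0:], lport)
	binary.BigEndian.PutUint16(buf[2:], rport)
	binary.BigEndian.PutUint32(buf[4:], seed)
	// A coarsened nanosecond clock stands in for the RFC 6528 timer; the
	// shift amount here is arbitrary, chosen only for illustration.
	binary.BigEndian.PutUint32(buf[8:], uint32(time.Now().UnixNano()>>6))
	h.Write(laddr)
	h.Write(raddr)
	h.Write(buf[:8]) // ports + secret seed
	isn := binary.BigEndian.Uint32(h.Sum(nil)[:4])
	return isn + binary.BigEndian.Uint32(buf[8:]) // add the clock component
}

func main() {
	fmt.Printf("%#x\n", secureISN([]byte{10, 0, 0, 1}, []byte{10, 0, 0, 2}, 12345, 80, 0xdeadbeef))
}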
iss seqnum.Value @@ -155,7 +156,7 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed()) } // generateSecureISN generates a secure Initial Sequence number based on the @@ -301,7 +302,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { ttl = h.ep.route.DefaultTTL() } h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -357,14 +358,14 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, - TS: h.ep.sendTSOk, + TS: h.ep.SendTSOk, TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), - SACKPermitted: h.ep.sackPermitted, + SACKPermitted: h.ep.SACKPermitted, MSS: h.ep.amss, } h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -389,13 +390,22 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. - if h.ep.sendTSOk && !s.parsedOptions.TS { + if h.ep.SendTSOk && !s.parsedOptions.TS { h.ep.stack.Stats().DroppedPackets.Increment() return nil } + // Drop the ACK if the accept queue is full. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_ipv4.c#L1523 + // We could abort the connection as well with a tunable as in + // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_minisocks.c#L788 + if listenEP := h.listenEP; listenEP != nil && listenEP.acceptQueueIsFull() { + listenEP.stack.Stats().DroppedPackets.Increment() + return nil + } + // Update timestamp if required. See RFC7323, section-4.3. - if h.ep.sendTSOk && s.parsedOptions.TS { + if h.ep.SendTSOk && s.parsedOptions.TS { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted @@ -485,8 +495,8 @@ func (h *handshake) start() { // start() is also called in a listen context so we want to make sure we only // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { - synOpts.TS = h.ep.sendTSOk - synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) + synOpts.TS = h.ep.SendTSOk + synOpts.SACKPermitted = h.ep.SACKPermitted && bool(sackEnabled) if h.sndWndScale < 0 { // Disable window scaling if the peer did not send us // the window scaling option. @@ -496,7 +506,7 @@ func (h *handshake) start() { h.sendSYNOpts = synOpts h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -544,7 +554,7 @@ func (h *handshake) complete() tcpip.Error { // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -845,7 +855,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // N.B. the ordering here matches the ordering used by Linux internally // and described in the raw makeOptions function. We don't include // unnecessary cases here (post connection.) 
- if e.sendTSOk { + if e.SendTSOk { // Embed the timestamp if timestamp has been enabled. // // We only use the lower 32 bits of the unix time in @@ -862,7 +872,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { offset += header.EncodeNOP(options[offset:]) offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } - if e.sackPermitted && len(sackBlocks) > 0 { + if e.SACKPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) @@ -884,7 +894,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se } options := e.makeOptions(sackBlocks) err := e.sendTCP(e.route, tcpFields{ - id: e.ID, + id: e.TransportEndpointInfo.ID, ttl: e.ttl, tos: e.sendTOS, flags: flags, @@ -898,9 +908,9 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se } func (e *endpoint) handleWrite() { - e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Lock() next := e.drainSendQueueLocked() - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() e.sendData(next) } @@ -909,10 +919,10 @@ func (e *endpoint) handleWrite() { // // Precondition: e.sndBufMu must be locked. func (e *endpoint) drainSendQueueLocked() *segment { - first := e.sndQueue.Front() + first := e.sndQueueInfo.sndQueue.Front() if first != nil { - e.snd.writeList.PushBackList(&e.sndQueue) - e.sndBufInQueue = 0 + e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue) + e.sndQueueInfo.SndBufInQueue = 0 } return first } @@ -936,7 +946,7 @@ func (e *endpoint) handleClose() { e.handleWrite() // Mark send side as closed. - e.snd.closed = true + e.snd.Closed = true } // resetConnectionLocked puts the endpoint in an error state with the given @@ -958,12 +968,12 @@ func (e *endpoint) resetConnectionLocked(err tcpip.Error) { // // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more // information. - sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + sndWndEnd := e.snd.SndUna.Add(e.snd.SndWnd) resetSeqNum := sndWndEnd - if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { - resetSeqNum = e.snd.sndNxt + if !sndWndEnd.LessThan(e.snd.SndNxt) || e.snd.SndNxt.Size(sndWndEnd) < (1<<e.snd.SndWndScale) { + resetSeqNum = e.snd.SndNxt } - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.RcvNxt, 0) } } @@ -989,13 +999,13 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. - e.rcvAutoParams.prevCopied = int(h.rcvWnd) - e.rcvListMu.Unlock() + e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd) + e.rcvQueueInfo.rcvQueueMu.Unlock() e.setEndpointState(StateEstablished) } @@ -1026,10 +1036,15 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. 
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.TransportEndpointInfo.ID, s.nicID) if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. - ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) + ep = e.stack.FindTransportEndpoint( + header.IPv4ProtocolNumber, + e.TransProto, + e.TransportEndpointInfo.ID, + s.nicID, + ) } if ep == nil { replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) @@ -1108,7 +1123,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { } // handleSegments processes all inbound segments. -func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { +// +// Precondition: e.mu must be held. +func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { if e.EndpointState().closed() { @@ -1120,7 +1137,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { break } - cont, err := e.handleSegment(s) + cont, err := e.handleSegmentLocked(s) s.decRef() if err != nil { return err @@ -1138,7 +1155,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { } // Send an ACK for all processed packets if needed. - if e.rcv.rcvNxt != e.snd.maxSentAck { + if e.rcv.RcvNxt != e.snd.MaxSentAck { e.snd.sendAck() } @@ -1147,18 +1164,21 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { return nil } -func (e *endpoint) probeSegment() { - if e.probe != nil { - e.probe(e.completeState()) +// Precondition: e.mu must be held. +func (e *endpoint) probeSegmentLocked() { + if fn := e.probe; fn != nil { + fn(e.completeStateLocked()) } } // handleSegment handles a given segment and notifies the worker goroutine if // if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { +// +// Precondition: e.mu must be held. +func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) { // Invoke the tcp probe if installed. The tcp probe function will update // the TCPEndpointState after the segment is processed. - defer e.probeSegment() + defer e.probeSegmentLocked() if s.flagIsSet(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { @@ -1191,7 +1211,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. - s.window <<= e.snd.sndWndScale + s.window <<= e.snd.SndWndScale // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment @@ -1255,7 +1275,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error { // seg.seq = snd.nxt-1. e.keepalive.unacked++ e.keepalive.Unlock() - e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1) + e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.SndNxt-1) e.resetKeepaliveTimer(false) return nil } @@ -1269,7 +1289,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } // Start the keepalive timer IFF it's enabled and there is no pending // data to send. 
- if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.SndUna != e.snd.SndNxt { e.keepalive.timer.disable() e.keepalive.Unlock() return @@ -1362,14 +1382,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ f func() tcpip.Error }{ { - w: &e.sndWaker, + w: &e.sndQueueInfo.sndWaker, f: func() tcpip.Error { e.handleWrite() return nil }, }, { - w: &e.sndCloseWaker, + w: &e.sndQueueInfo.sndCloseWaker, f: func() tcpip.Error { e.handleClose() return nil @@ -1403,7 +1423,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ { w: &e.newSegmentWaker, f: func() tcpip.Error { - return e.handleSegments(false /* fastPath */) + return e.handleSegmentsLocked(false /* fastPath */) }, }, { @@ -1419,11 +1439,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n&notifyMTUChanged != 0 { - e.sndBufMu.Lock() - count := e.packetTooBigCount - e.packetTooBigCount = 0 - mtu := e.sndMTU - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Lock() + count := e.sndQueueInfo.PacketTooBigCount + e.sndQueueInfo.PacketTooBigCount = 0 + mtu := e.sndQueueInfo.SndMTU + e.sndQueueInfo.sndQueueMu.Unlock() e.snd.updateMaxPayloadSize(mtu, count) } @@ -1453,7 +1473,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if n&notifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(false /* fastPath */); err != nil { + if err := e.handleSegmentsLocked(false /* fastPath */); err != nil { return err } } @@ -1504,11 +1524,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.newSegmentWaker.Assert() } - e.rcvListMu.Lock() - if !e.rcvList.Empty() { + e.rcvQueueInfo.rcvQueueMu.Lock() + if !e.rcvQueueInfo.rcvQueue.Empty() { e.waiterQueue.Notify(waiter.ReadableEvents) } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 1975f1a44..962f1d687 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -17,6 +17,8 @@ package tcp import ( "math" "time" + + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // cubicState stores the variables related to TCP CUBIC congestion @@ -25,47 +27,12 @@ import ( // See: https://tools.ietf.org/html/rfc8312. // +stateify savable type cubicState struct { - // wLastMax is the previous wMax value. - wLastMax float64 - - // wMax is the value of the congestion window at the - // time of last congestion event. - wMax float64 - - // t denotes the time when the current congestion avoidance - // was entered. - t time.Time `state:".(unixTime)"` + stack.TCPCubicState // numCongestionEvents tracks the number of congestion events since last // RTO. numCongestionEvents int - // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as - // per RFC. - c float64 - - // k is the time period that the above function takes to increase the - // current window size to W_max if there are no further congestion - // events and is calculated using the following equation: - // - // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2) - k float64 - - // beta is the CUBIC multiplication decrease factor. that is, when a - // congestion event is detected, CUBIC reduces its cwnd to - // W_cubic(0)=W_max*beta_cubic. - beta float64 - - // wC is window computed by CUBIC at time t.
It's calculated using the - // formula: - // - // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) - wC float64 - - // wEst is the window computed by CUBIC at time t+RTT i.e - // W_cubic(t+RTT). - wEst float64 - s *sender } @@ -73,10 +40,12 @@ type cubicState struct { // beta and c set and t set to current time. func newCubicCC(s *sender) *cubicState { return &cubicState{ - t: time.Now(), - beta: 0.7, - c: 0.4, - s: s, + TCPCubicState: stack.TCPCubicState{ + T: time.Now(), + Beta: 0.7, + C: 0.4, + }, + s: s, } } @@ -90,10 +59,10 @@ func (c *cubicState) enterCongestionAvoidance() { // See: https://tools.ietf.org/html/rfc8312#section-4.7 & // https://tools.ietf.org/html/rfc8312#section-4.8 if c.numCongestionEvents == 0 { - c.k = 0 - c.t = time.Now() - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.K = 0 + c.T = time.Now() + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) } } @@ -104,16 +73,16 @@ func (c *cubicState) enterCongestionAvoidance() { func (c *cubicState) updateSlowStart(packetsAcked int) int { // Don't let the congestion window cross into the congestion // avoidance range. - newcwnd := c.s.sndCwnd + packetsAcked + newcwnd := c.s.SndCwnd + packetsAcked enterCA := false - if newcwnd >= c.s.sndSsthresh { - newcwnd = c.s.sndSsthresh - c.s.sndCAAckCount = 0 + if newcwnd >= c.s.Ssthresh { + newcwnd = c.s.Ssthresh + c.s.SndCAAckCount = 0 enterCA = true } - packetsAcked -= newcwnd - c.s.sndCwnd - c.s.sndCwnd = newcwnd + packetsAcked -= newcwnd - c.s.SndCwnd + c.s.SndCwnd = newcwnd if enterCA { c.enterCongestionAvoidance() } @@ -124,49 +93,49 @@ func (c *cubicState) updateSlowStart(packetsAcked int) int { // ACK received. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) Update(packetsAcked int) { - if c.s.sndCwnd < c.s.sndSsthresh { + if c.s.SndCwnd < c.s.Ssthresh { packetsAcked = c.updateSlowStart(packetsAcked) if packetsAcked == 0 { return } } else { c.s.rtt.Lock() - srtt := c.s.rtt.srtt + srtt := c.s.rtt.TCPRTTState.SRTT c.s.rtt.Unlock() - c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt) + c.s.SndCwnd = c.getCwnd(packetsAcked, c.s.SndCwnd, srtt) } } // cubicCwnd computes the CUBIC congestion window after t seconds from last // congestion event. func (c *cubicState) cubicCwnd(t float64) float64 { - return c.c*math.Pow(t, 3.0) + c.wMax + return c.C*math.Pow(t, 3.0) + c.WMax } // getCwnd returns the current congestion window as computed by CUBIC. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { - elapsed := time.Since(c.t).Seconds() + elapsed := time.Since(c.T).Seconds() // Compute the window as per Cubic after 'elapsed' time // since last congestion event. - c.wC = c.cubicCwnd(elapsed - c.k) + c.WC = c.cubicCwnd(elapsed - c.K) // Compute the TCP friendly estimate of the congestion window. - c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds()) + c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds()) // Make sure in the TCP friendly region CUBIC performs at least // as well as Reno. - if c.wC < c.wEst && float64(sndCwnd) < c.wEst { + if c.WC < c.WEst && float64(sndCwnd) < c.WEst { // TCP Friendly region of cubic. - return int(c.wEst) + return int(c.WEst) } // In Concave/Convex region of CUBIC, calculate what CUBIC window // will be after 1 RTT and use that to grow congestion window // for every ack. 
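To make the constants referenced above concrete, the following standalone calculation evaluates the RFC 8312 window W_cubic(t) = C*(t-K)^3 + W_max with the same C = 0.4 and beta = 0.7 defaults that newCubicCC sets; the W_max value fed in is made up:

package main

import (
	"fmt"
	"math"
)

const (
	c    = 0.4 // cubic constant from RFC 8312
	beta = 0.7 // multiplicative decrease factor
)

// wCubic evaluates W_cubic(t) = C*(t-K)^3 + W_max (RFC 8312, Eq. 1).
func wCubic(t, k, wMax float64) float64 {
	return c*math.Pow(t-k, 3) + wMax
}

func main() {
	wMax := 100.0 // cwnd (in packets) at the last congestion event; made-up value
	// K = cubic_root(W_max*(1-beta)/C) (RFC 8312, Eq. 2): the time needed to
	// climb back to W_max with no further losses.
	k := math.Cbrt(wMax * (1 - beta) / c)
	for _, t := range []float64{0, k / 2, k, 2 * k} {
		fmt.Printf("t=%5.2fs W_cubic=%6.2f\n", t, wCubic(t, k, wMax))
	}
}

At t=0 this yields beta*W_max (70 packets here), and at t=K it returns exactly W_max, matching the shape the comments above describe.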
- tEst := (time.Since(c.t) + srtt).Seconds() - wtRtt := c.cubicCwnd(tEst - c.k) + tEst := (time.Since(c.T) + srtt).Seconds() + wtRtt := c.cubicCwnd(tEst - c.K) // As per 4.3 for each received ACK cwnd must be incremented // by (w_cubic(t+RTT) - cwnd/cwnd. cwnd := float64(sndCwnd) @@ -182,9 +151,9 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ - c.t = time.Now() - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.T = time.Now() + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) c.fastConvergence() c.reduceSlowStartThreshold() @@ -193,10 +162,10 @@ func (c *cubicState) HandleLossDetected() { // HandleRTOExpired implements congestionContrl.HandleRTOExpired. func (c *cubicState) HandleRTOExpired() { // See: https://tools.ietf.org/html/rfc8312#section-4.6 - c.t = time.Now() + c.T = time.Now() c.numCongestionEvents = 0 - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) c.fastConvergence() @@ -206,29 +175,29 @@ func (c *cubicState) HandleRTOExpired() { // Reduce the congestion window to 1, i.e., enter slow-start. Per // RFC 5681, page 7, we must use 1 regardless of the value of the // initial congestion window. - c.s.sndCwnd = 1 + c.s.SndCwnd = 1 } // fastConvergence implements the logic for Fast Convergence algorithm as // described in https://tools.ietf.org/html/rfc8312#section-4.6. func (c *cubicState) fastConvergence() { - if c.wMax < c.wLastMax { - c.wLastMax = c.wMax - c.wMax = c.wMax * (1.0 + c.beta) / 2.0 + if c.WMax < c.WLastMax { + c.WLastMax = c.WMax + c.WMax = c.WMax * (1.0 + c.Beta) / 2.0 } else { - c.wLastMax = c.wMax + c.WLastMax = c.WMax } // Recompute k as wMax may have changed. - c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c) + c.K = math.Cbrt(c.WMax * (1 - c.Beta) / c.C) } // PostRecovery implemements congestionControl.PostRecovery. func (c *cubicState) PostRecovery() { - c.t = time.Now() + c.T = time.Now() } // reduceSlowStartThreshold returns new SsThresh as described in // https://tools.ietf.org/html/rfc8312#section-4.7. func (c *cubicState) reduceSlowStartThreshold() { - c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0)) + c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0)) } diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 21162f01a..512053a04 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -116,7 +116,7 @@ func (p *processor) start(wg *sync.WaitGroup) { if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { // If the endpoint is in a connected state then we do direct delivery // to ensure low latency and avoid scheduler interactions. - switch err := ep.handleSegments(true /* fastPath */); { + switch err := ep.handleSegmentsLocked(true /* fastPath */); { case err != nil: // Send any active resets if required. ep.resetConnectionLocked(err) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f6a16f96e..d6d68f128 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -565,17 +565,15 @@ func TestV4AcceptOnV4(t *testing.T) { } func testV4ListenClose(t *testing.T, c *context.Context) { - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. 
- var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } - const n = uint16(32) + const n = 32 // Start listening. - if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil { + if err := c.EP.Listen(n); err != nil { t.Fatalf("Listen failed: %v", err) } @@ -591,9 +589,9 @@ func testV4ListenClose(t *testing.T, c *context.Context) { }) } - // Each of these ACK's will cause a syn-cookie based connection to be + // Each of these ACKs will cause a syn-cookie based connection to be // accepted and delivered to the listening endpoint. - for i := uint16(0); i < n; i++ { + for i := 0; i < n; i++ { b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) iss := seqnum.Value(tcp.SequenceNumber()) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c5daba232..f25dc781a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "container/list" "encoding/binary" "fmt" "io" @@ -190,42 +191,6 @@ type SACKInfo struct { NumBlocks int } -// rcvBufAutoTuneParams are used to hold state variables to compute -// the auto tuned recv buffer size. -// -// +stateify savable -type rcvBufAutoTuneParams struct { - // measureTime is the time at which the current measurement - // was started. - measureTime time.Time `state:".(unixTime)"` - - // copied is the number of bytes copied out of the receive - // buffers since this measure began. - copied int - - // prevCopied is the number of bytes copied out of the receive - // buffers in the previous RTT period. - prevCopied int - - // rtt is the non-smoothed minimum RTT as measured by observing the time - // between when a byte is first acknowledged and the receipt of data - // that is at least one window beyond the sequence number that was - // acknowledged. - rtt time.Duration - - // rttMeasureSeqNumber is the highest acceptable sequence number at the - // time this RTT measurement period began. - rttMeasureSeqNumber seqnum.Value - - // rttMeasureTime is the absolute time at which the current rtt - // measurement period began. - rttMeasureTime time.Time `state:".(unixTime)"` - - // disabled is true if an explicit receive buffer is set for the - // endpoint. - disabled bool -} - // ReceiveErrors collect segment receive errors within transport layer. type ReceiveErrors struct { tcpip.ReceiveErrors @@ -246,7 +211,7 @@ type ReceiveErrors struct { ListenOverflowAckDrop tcpip.StatCounter // ZeroRcvWindowState is the number of times we advertised - // a zero receive window when rcvList is full. + // a zero receive window when rcvQueue is full. ZeroRcvWindowState tcpip.StatCounter // WantZeroWindow is the number of times we wanted to advertise a @@ -309,18 +274,45 @@ type Stats struct { // marker interface. func (*Stats) IsEndpointStats() {} -// EndpointInfo holds useful information about a transport endpoint which -// can be queried by monitoring tools. This exists to allow tcp-only state to -// be exposed. +// sndQueueInfo implements a send queue. 
// // +stateify savable -type EndpointInfo struct { - stack.TransportEndpointInfo +type sndQueueInfo struct { + sndQueueMu sync.Mutex `state:"nosave"` + stack.TCPSndBufState + + // sndQueue holds segments that are ready to be sent. + sndQueue segmentList `state:"wait"` + + // sndWaker is used to signal the protocol goroutine when segments are + // added to the `sndQueue`. + sndWaker sleep.Waker `state:"manual"` + + // sndCloseWaker is used to notify the protocol goroutine when the send + // side is closed. + sndCloseWaker sleep.Waker `state:"manual"` } -// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo -// marker interface. -func (*EndpointInfo) IsEndpointInfo() {} +// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata. +// +// +stateify savable +type rcvQueueInfo struct { + rcvQueueMu sync.Mutex `state:"nosave"` + stack.TCPRcvBufState + + // rcvQueue is the queue for ready-for-delivery segments. This struct's + // mutex must be held in order append segments to list. + rcvQueue segmentList `state:"wait"` +} + +// +stateify savable +type accepted struct { + // NB: this could be an endpointList, but ilist only permits endpoints to + // belong to one list at a time, and endpoints are already stored in the + // dispatcher's list. + endpoints list.List `state:".([]*endpoint)"` + cap int +} // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to @@ -337,9 +329,9 @@ func (*EndpointInfo) IsEndpointInfo() {} // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects acceptedChan. -// e.rcvListMu -> Protects the rcvList and associated fields. -// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.acceptMu -> protects accepted. +// e.rcvQueueMu -> Protects e.rcvQueue and associated fields. +// e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. // // LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different @@ -362,7 +354,8 @@ func (*EndpointInfo) IsEndpointInfo() {} // // +stateify savable type endpoint struct { - EndpointInfo + stack.TCPEndpointStateInner + stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the @@ -395,38 +388,23 @@ type endpoint struct { // rcvReadMu synchronizes calls to Read. // - // mu and rcvListMu are temporarily released during data copying. rcvReadMu + // mu and rcvQueueMu are temporarily released during data copying. rcvReadMu // must be held during each read to ensure atomicity, so that multiple reads // do not interleave. // // rcvReadMu should be held before holding mu. rcvReadMu sync.Mutex `state:"nosave"` - // rcvListMu synchronizes access to rcvList. - // - // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - - // rcvList is the queue for ready-for-delivery segments. - // - // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data - // and removing segments from list. A range of segment can be determined, then - // temporarily release mu and rcvListMu while processing the segment range. - // This allows new segments to be appended to the list while processing. - // - // rcvListMu must be held to append segments to list. 
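The accepted struct above replaces the old buffered channel with a container/list plus an explicit capacity, so a later Listen call can raise the backlog without closing and reallocating a channel. A toy version of such a bounded queue, with a sync.Cond standing in for the wait-until-there-is-room behaviour of deliverAccepted; names and types are illustrative:

package main

import (
	"container/list"
	"fmt"
	"sync"
)

// acceptQueue is a capacity-bounded FIFO of ready endpoints.
type acceptQueue struct {
	mu        sync.Mutex
	cond      *sync.Cond
	endpoints list.List
	cap       int
}

func newAcceptQueue(backlog int) *acceptQueue {
	q := &acceptQueue{cap: backlog + 1} // capacity is backlog+1, as in the diff above
	q.cond = sync.NewCond(&q.mu)
	return q
}

// deliver blocks while the queue is full, then appends the endpoint.
func (q *acceptQueue) deliver(ep string) {
	q.mu.Lock()
	defer q.mu.Unlock()
	for q.endpoints.Len() == q.cap {
		q.cond.Wait()
	}
	q.endpoints.PushBack(ep)
}

// accept pops the oldest endpoint and wakes any blocked deliverer.
func (q *acceptQueue) accept() (string, bool) {
	q.mu.Lock()
	defer q.mu.Unlock()
	front := q.endpoints.Front()
	if front == nil {
		return "", false
	}
	q.endpoints.Remove(front)
	q.cond.Broadcast()
	return front.Value.(string), true
}

func main() {
	q := newAcceptQueue(1)
	q.deliver("conn-1")
	ep, _ := q.accept()
	fmt.Println(ep)
}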
- rcvList segmentList `state:"wait"` - rcvClosed bool - // rcvBufSize is the total size of the receive buffer. - rcvBufSize int - // rcvBufUsed is the actual number of payload bytes held in the receive buffer - // not counting any overheads of the segments itself. NOTE: This will always - // be strictly <= rcvMemUsed below. - rcvBufUsed int - rcvAutoParams rcvBufAutoTuneParams + // rcvQueueInfo holds the implementation of the endpoint's receive buffer. + // The data within rcvQueueInfo should only be accessed while rcvReadMu, mu, + // and rcvQueueMu are held, in that stated order. While processing the segment + // range, you can determine a range and then temporarily release mu and + // rcvQueueMu, which allows new segments to be appended to the queue while + // processing. + rcvQueueInfo rcvQueueInfo // rcvMemUsed tracks the total amount of memory in use by received segments - // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // held in rcvQueue, pendingRcvdSegments and the segment queue. This is used to // compute the window and the actual available buffer space. This is distinct // from rcvBufUsed above which is the actual number of payload bytes held in // the buffer not including any segment overheads. @@ -488,33 +466,16 @@ type endpoint struct { // also true, and they're both protected by the mutex. workerCleanup bool - // sendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1 - sendTSOk bool - - // recentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - recentTS uint32 - - // recentTSTime is the unix time when we updated recentTS last. + // recentTSTime is the unix time when we last updated + // TCPEndpointStateInner.RecentTS. recentTSTime time.Time `state:".(unixTime)"` - // tsOffset is a randomized offset added to the value of the - // TSVal field in the timestamp option. - tsOffset uint32 - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags // tcpRecovery is the loss deteoction algorithm used by TCP. tcpRecovery tcpip.TCPRecovery - // sackPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - sackPermitted bool - // sack holds TCP SACK related information for this endpoint. sack SACKInfo @@ -550,32 +511,13 @@ type endpoint struct { // this value. windowClamp uint32 - // The following fields are used to manage the send buffer. When - // segments are ready to be sent, they are added to sndQueue and the - // protocol goroutine is signaled via sndWaker. - // - // When the send side is closed, the protocol goroutine is notified via - // sndCloseWaker, and sndClosed is set to true. - sndBufMu sync.Mutex `state:"nosave"` - sndBufUsed int - sndClosed bool - sndBufInQueue seqnum.Size - sndQueue segmentList `state:"wait"` - sndWaker sleep.Waker `state:"manual"` - sndCloseWaker sleep.Waker `state:"manual"` + // sndQueueInfo contains the implementation of the endpoint's send queue. + sndQueueInfo sndQueueInfo // cc stores the name of the Congestion Control algorithm to use for // this endpoint. cc tcpip.CongestionControlOption - // The following are used when a "packet too big" control packet is - // received. They are protected by sndBufMu. 
They are used to - // communicate to the main protocol goroutine how many such control - // messages have been received since the last notification was processed - // and what was the smallest MTU seen. - packetTooBigCount int - sndMTU int - // newSegmentWaker is used to indicate to the protocol goroutine that // it needs to wake up and handle new segments queued to it. newSegmentWaker sleep.Waker `state:"manual"` @@ -607,33 +549,26 @@ type endpoint struct { // listener. deferAccept time.Duration - // pendingAccepted is a synchronization primitive used to track number - // of connections that are queued up to be delivered to the accepted - // channel. We use this to ensure that all goroutines blocked on writing - // to the acceptedChan below terminate before we close acceptedChan. + // pendingAccepted tracks connections queued to be accepted. It is used to + // ensure such queued connections are terminated before the accepted queue is + // marked closed (by setting its capacity to zero). pendingAccepted sync.WaitGroup `state:"nosave"` - // acceptMu protects acceptedChan. + // acceptMu protects accepted. acceptMu sync.Mutex `state:"nosave"` // acceptCond is a condition variable that can be used to block on when - // acceptedChan is full and an endpoint is ready to be delivered. - // - // This condition variable is required because just blocking on sending - // to acceptedChan does not work in cases where endpoint.Listen is - // called twice with different backlog values. In such cases the channel - // is closed and a new one created. Any pending goroutines blocking on - // the write to the channel will panic. + // accepted is full and an endpoint is ready to be delivered. // // We use this condition variable to block/unblock goroutines which // tried to deliver an endpoint but couldn't because accept backlog was // full ( See: endpoint.deliverAccepted ). acceptCond *sync.Cond `state:"nosave"` - // acceptedChan is used by a listening endpoint protocol goroutine to + // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. - acceptedChan chan *endpoint `state:".([]*endpoint)"` + accepted accepted // The following are only used from the protocol goroutine, and // therefore don't need locks to protect them. @@ -779,7 +714,7 @@ func (e *endpoint) UnlockUser() { switch e.EndpointState() { case StateEstablished: - if err := e.handleSegments(true /* fastPath */); err != nil { + if err := e.handleSegmentsLocked(true /* fastPath */); err != nil { e.notifyProtocolGoroutine(notifyTickleWorker) } default: @@ -839,13 +774,13 @@ func (e *endpoint) EndpointState() EndpointState { // setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { - e.recentTS = recentTS + e.RecentTS = recentTS e.recentTSTime = time.Now() } // recentTimestamp returns the value of the recentTS field. func (e *endpoint) recentTimestamp() uint32 { - return e.recentTS + return e.RecentTS } // keepalive is a synchronization wrapper used to appease stateify. 
See the @@ -865,16 +800,17 @@ type keepalive struct { func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ stack: s, - EndpointInfo: EndpointInfo{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: header.TCPProtocolNumber, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + sndQueueInfo: sndQueueInfo{ + TCPSndBufState: stack.TCPSndBufState{ + SndMTU: int(math.MaxInt32), }, }, waiterQueue: waiterQueue, state: StateInitial, - rcvBufSize: DefaultReceiveBufferSize, - sndMTU: int(math.MaxInt32), keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -886,10 +822,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) + e.ops.SetReceiveBufferSize(DefaultReceiveBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -898,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - e.rcvBufSize = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } var cs tcpip.CongestionControlOption @@ -908,7 +845,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { - e.rcvAutoParams.disabled = !bool(mrb) + e.rcvQueueInfo.RcvAutoParams.Disabled = !bool(mrb) } var de tcpip.TCPDelayEnabled @@ -933,7 +870,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.tsOffset = timeStampOffset() + e.TSOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) e.keepalive.timer.init(&e.keepalive.waker) @@ -959,10 +896,10 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result = mask case StateListen: - // Check if there's anything in the accepted channel. + // Check if there's anything in the accepted queue. if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() - if len(e.acceptedChan) > 0 { + if e.accepted.endpoints.Len() != 0 { result |= waiter.ReadableEvents } e.acceptMu.Unlock() @@ -971,21 +908,21 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.WritableEvents) != 0 { - e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Lock() sndBufSize := e.getSendBufferSize() - if e.sndClosed || e.sndBufUsed < sndBufSize { + if e.sndQueueInfo.SndClosed || e.sndQueueInfo.SndBufUsed < sndBufSize { result |= waiter.WritableEvents } - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() } // Determine if the endpoint is readable if requested. 
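Throughout this hunk the per-endpoint rcvBufSize field gives way to the shared SocketOptions accessors, with stack-level limits supplied through a getter passed to InitHandler. One plausible shape for such a setter is sketched below; the clamping rule and the limit values are assumptions for illustration, not the exact ops.SetReceiveBufferSize behaviour:

package main

import "fmt"

// bufferLimits mirrors the min/default/max triple a limits getter returns.
type bufferLimits struct{ min, def, max int64 }

// socketOptions stores the receive buffer size and clamps writes against the
// limits provided by the installed getter.
type socketOptions struct {
	limits  func() bufferLimits
	rcvSize int64
}

func (so *socketOptions) SetReceiveBufferSize(v int64) {
	l := so.limits()
	if v < l.min {
		v = l.min
	}
	if v > l.max {
		v = l.max
	}
	so.rcvSize = v
}

func (so *socketOptions) GetReceiveBufferSize() int64 { return so.rcvSize }

func main() {
	so := &socketOptions{limits: func() bufferLimits {
		return bufferLimits{min: 4096, def: 212992, max: 4 << 20} // invented numbers
	}}
	so.SetReceiveBufferSize(1 << 30) // request far above the max
	fmt.Println(so.GetReceiveBufferSize())
}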
if (mask & waiter.ReadableEvents) != 0 { - e.rcvListMu.Lock() - if e.rcvBufUsed > 0 || e.rcvClosed { + e.rcvQueueInfo.rcvQueueMu.Lock() + if e.rcvQueueInfo.RcvBufUsed > 0 || e.rcvQueueInfo.RcvClosed { result |= waiter.ReadableEvents } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() } } @@ -1093,15 +1030,15 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: e.boundDest, @@ -1145,22 +1082,22 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - if e.acceptedChan == nil { - e.acceptMu.Unlock() + acceptedCopy := e.accepted + e.accepted = accepted{} + e.acceptMu.Unlock() + + if acceptedCopy == (accepted{}) { return } - close(e.acceptedChan) - ch := e.acceptedChan - e.acceptedChan = nil + e.acceptCond.Broadcast() - e.acceptMu.Unlock() // Reset all connections that are waiting to be accepted. - for n := range ch { - n.notifyProtocolGoroutine(notifyReset) + for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { + n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) } // Wait for reset of all endpoints that are still waiting to be delivered to - // the now closed acceptedChan. + // the now closed accepted. 
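The teardown above follows a snapshot-and-wait shape: the accepted queue is captured and cleared under acceptMu, the queued endpoints are reset, and pendingAccepted is waited on so in-flight deliveries finish before the listener goes away. A toy sketch of that shape; it is simplified and the ordering details of the real code differ:

package main

import (
	"fmt"
	"sync"
)

// closeDemo tracks in-flight deliveries in a WaitGroup, snapshots and clears
// the queue under the lock, and waits for pending deliveries on close.
type closeDemo struct {
	mu      sync.Mutex
	queue   []string
	pending sync.WaitGroup
}

func (c *closeDemo) deliver(ep string) {
	c.pending.Add(1)
	defer c.pending.Done()
	c.mu.Lock()
	c.queue = append(c.queue, ep)
	c.mu.Unlock()
}

func (c *closeDemo) close() []string {
	c.mu.Lock()
	snapshot := c.queue
	c.queue = nil // later deliveries see a closed queue
	c.mu.Unlock()
	c.pending.Wait() // let in-flight deliveries finish first
	return snapshot  // these would each get a reset notification
}

func main() {
	c := &closeDemo{}
	c.deliver("conn-1")
	fmt.Println(c.close())
}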
e.pendingAccepted.Wait() } @@ -1176,7 +1113,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1184,8 +1121,8 @@ func (e *endpoint) cleanupLocked() { portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: e.boundDest, @@ -1247,19 +1184,19 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - if e.rcvAutoParams.disabled { - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + if e.rcvQueueInfo.RcvAutoParams.Disabled { + e.rcvQueueInfo.rcvQueueMu.Unlock() return } now := time.Now() - if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { - e.rcvAutoParams.copied += copied - e.rcvListMu.Unlock() + if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt { + e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied + e.rcvQueueInfo.rcvQueueMu.Unlock() return } - prevRTTCopied := e.rcvAutoParams.copied + copied - prevCopied := e.rcvAutoParams.prevCopied + prevRTTCopied := e.rcvQueueInfo.RcvAutoParams.CopiedBytes + copied + prevCopied := e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes rcvWnd := 0 if prevRTTCopied > prevCopied { // The minimal receive window based on what was copied by the app @@ -1291,24 +1228,25 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // We do not adjust downwards as that can cause the receiver to // reject valid data that might already be in flight as the // acceptable window will shrink. - if rcvWnd > e.rcvBufSize { - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvBufSize = rcvWnd - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { + rcvBufSize := int(e.ops.GetReceiveBufferSize()) + if rcvWnd > rcvBufSize { + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + e.ops.SetReceiveBufferSize(int64(rcvWnd), false /* notify */) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(rcvWnd)) + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, rcvBufSize); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } - // We only update prevCopied when we grow the buffer because in cases - // where prevCopied > prevRTTCopied the existing buffer is already big + // We only update PrevCopiedBytes when we grow the buffer because in cases + // where PrevCopiedBytes > prevRTTCopied the existing buffer is already big // enough to handle the current rate and we don't need to do any // adjustments. 
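The auto-tuning rewrite above keeps the same rule: once per measured RTT, compare how much the application copied out against the previous period and only ever grow the advertised buffer. A compressed sketch of that rule, with the growth factor and clamping simplified relative to the hunk:

package main

import (
	"fmt"
	"time"
)

// autoTune tracks per-RTT copy counters and suggests a new receive buffer size.
type autoTune struct {
	measureTime time.Time
	rtt         time.Duration
	copied      int
	prevCopied  int
	bufSize     int
}

// onCopy is called after the application reads n bytes; it returns the
// (possibly grown) buffer size once a full RTT of measurements is available.
func (a *autoTune) onCopy(n int, now time.Time) int {
	if a.rtt == 0 || now.Sub(a.measureTime) < a.rtt {
		a.copied += n
		return a.bufSize
	}
	copiedThisRTT := a.copied + n
	if copiedThisRTT > a.prevCopied {
		// Grow towards what the application actually consumed; never shrink,
		// since data already in flight was advertised against the old window.
		if want := 2 * copiedThisRTT; want > a.bufSize {
			a.bufSize = want
		}
		a.prevCopied = copiedThisRTT
	}
	a.measureTime = now
	a.copied = 0
	return a.bufSize
}

func main() {
	a := &autoTune{rtt: 100 * time.Millisecond, bufSize: 64 << 10, measureTime: time.Now().Add(-time.Second)}
	fmt.Println(a.onCopy(256<<10, time.Now())) // the application drained 256 KiB in the last RTT
}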
- e.rcvAutoParams.prevCopied = prevRTTCopied + e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = prevRTTCopied } - e.rcvAutoParams.measureTime = now - e.rcvAutoParams.copied = 0 - e.rcvListMu.Unlock() + e.rcvQueueInfo.RcvAutoParams.MeasureTime = now + e.rcvQueueInfo.RcvAutoParams.CopiedBytes = 0 + e.rcvQueueInfo.rcvQueueMu.Unlock() } // SetOwner implements tcpip.Endpoint.SetOwner. @@ -1357,7 +1295,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult defer e.rcvReadMu.Unlock() // N.B. Here we get a range of segments to be processed. It is safe to not - // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we + // hold rcvQueueMu when processing, since we hold rcvReadMu to ensure only we // can remove segments from the list through commitRead(). first, last, serr := e.startRead() if serr != nil { @@ -1429,10 +1367,10 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() - bufUsed := e.rcvBufUsed + bufUsed := e.rcvQueueInfo.RcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { if s == StateError { if err := e.hardErrorLocked(); err != nil { @@ -1444,14 +1382,14 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { return nil, nil, &tcpip.ErrNotConnected{} } - if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.EndpointState().connected() { + if e.rcvQueueInfo.RcvBufUsed == 0 { + if e.rcvQueueInfo.RcvClosed || !e.EndpointState().connected() { return nil, nil, &tcpip.ErrClosedForReceive{} } return nil, nil, &tcpip.ErrWouldBlock{} } - return e.rcvList.Front(), e.rcvList.Back(), nil + return e.rcvQueueInfo.rcvQueue.Front(), e.rcvQueueInfo.rcvQueue.Back(), nil } // commitRead commits a read of done bytes and returns the next non-empty @@ -1467,39 +1405,39 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { func (e *endpoint) commitRead(done int) *segment { e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() memDelta := 0 - s := e.rcvList.Front() + s := e.rcvQueueInfo.rcvQueue.Front() for s != nil && s.data.Size() == 0 { - e.rcvList.Remove(s) + e.rcvQueueInfo.rcvQueue.Remove(s) // Memory is only considered released when the whole segment has been // read. memDelta += s.segMemSize() s.decRef() - s = e.rcvList.Front() + s = e.rcvQueueInfo.rcvQueue.Front() } - e.rcvBufUsed -= done + e.rcvQueueInfo.RcvBufUsed -= done if memDelta > 0 { // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } - return e.rcvList.Front() + return e.rcvQueueInfo.rcvQueue.Front() } // isEndpointWritableLocked checks if a given endpoint is writable // and also returns the number of bytes that can be written at this // moment. 
If the endpoint is not writable then it returns an error // indicating the reason why it's not writable. -// Caller must hold e.mu and e.sndBufMu +// Caller must hold e.mu and e.sndQueueMu func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { @@ -1519,12 +1457,12 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { } // Check if the connection has already been closed for sends. - if e.sndClosed { + if e.sndQueueInfo.SndClosed { return 0, &tcpip.ErrClosedForSend{} } sndBufSize := e.getSendBufferSize() - avail := sndBufSize - e.sndBufUsed + avail := sndBufSize - e.sndQueueInfo.SndBufUsed if avail <= 0 { return 0, &tcpip.ErrWouldBlock{} } @@ -1541,8 +1479,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp defer e.UnlockUser() nextSeg, n, err := func() (*segment, int, tcpip.Error) { - e.sndBufMu.Lock() - defer e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Lock() + defer e.sndQueueInfo.sndQueueMu.Unlock() avail, err := e.isEndpointWritableLocked() if err != nil { @@ -1557,8 +1495,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // available buffer space to be consumed by some other caller while we // are copying data in. if !opts.Atomic { - e.sndBufMu.Unlock() - defer e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Unlock() + defer e.sndQueueInfo.sndQueueMu.Lock() e.UnlockUser() defer e.LockUser() @@ -1600,10 +1538,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, v) + e.sndQueueInfo.SndBufUsed += len(v) + e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) + e.sndQueueInfo.sndQueue.PushBack(s) return e.drainSendQueueLocked(), len(v), nil }() @@ -1618,11 +1556,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // selectWindowLocked returns the new window without checking for shrinking or scaling // applied. -// Precondition: e.mu and e.rcvListMu must be held. -func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { - wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked()) - maxWindow := wndFromSpace(e.rcvBufSize) - wndFromUsedBytes := maxWindow - e.rcvBufUsed +// Precondition: e.mu and e.rcvQueueMu must be held. +func (e *endpoint) selectWindowLocked(rcvBufSize int) (wnd seqnum.Size) { + wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + maxWindow := wndFromSpace(rcvBufSize) + wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in // cases where we receive a lot of small segments the segment overhead is a @@ -1640,11 +1578,11 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { return seqnum.Size(newWnd) } -// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu. +// selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu. 
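selectWindowLocked above now takes the buffer size as a parameter instead of reading e.rcvBufSize under rcvListMu, but its result is still the smaller of two candidate windows, as in this sketch (the SegOverheadFactor value is assumed for illustration):

package tcpsketch

// segOverheadFactor accounts for per-segment bookkeeping overhead when
// converting buffer space into an advertisable window (assumed value).
const segOverheadFactor = 8

func wndFromSpace(space int) int { return space / segOverheadFactor }

// selectWindowSketch mirrors selectWindowLocked: the advertised window is the
// smaller of the window derived from free buffer memory and the window
// derived from unread bytes, so many tiny segments (high overhead per payload
// byte) cannot inflate the advertisement.
func selectWindowSketch(rcvBufSize, rcvBufUsed, memAvailable int) int {
	wndFromAvailable := wndFromSpace(memAvailable)
	maxWindow := wndFromSpace(rcvBufSize)
	wndFromUsedBytes := maxWindow - rcvBufUsed

	newWnd := wndFromAvailable
	if newWnd > wndFromUsedBytes {
		newWnd = wndFromUsedBytes
	}
	if newWnd < 0 {
		newWnd = 0
	}
	return newWnd
}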
func (e *endpoint) selectWindow() (wnd seqnum.Size) { - e.rcvListMu.Lock() - wnd = e.selectWindowLocked() - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + wnd = e.selectWindowLocked(int(e.ops.GetReceiveBufferSize())) + e.rcvQueueInfo.rcvQueueMu.Unlock() return wnd } @@ -1662,9 +1600,9 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) { // above will be true if the new window is >= ACK threshold and false // otherwise. // -// Precondition: e.mu and e.rcvListMu must be held. -func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := int(e.selectWindowLocked()) +// Precondition: e.mu and e.rcvQueueMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int, rcvBufSize int) (crossed bool, above bool) { + newAvail := int(e.selectWindowLocked(rcvBufSize)) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 @@ -1673,7 +1611,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // rcvBufFraction is the inverse of the fraction of receive buffer size that // is used to decide if the available buffer space is now above it. const rcvBufFraction = 2 - if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + if wndThreshold := wndFromSpace(rcvBufSize / rcvBufFraction); threshold > wndThreshold { threshold = wndThreshold } switch { @@ -1700,7 +1638,7 @@ func (e *endpoint) OnReusePortSet(v bool) { } // OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet. -func (e *endpoint) OnKeepAliveSet(v bool) { +func (e *endpoint) OnKeepAliveSet(bool) { e.notifyProtocolGoroutine(notifyKeepaliveChanged) } @@ -1708,7 +1646,7 @@ func (e *endpoint) OnKeepAliveSet(v bool) { func (e *endpoint) OnDelayOptionSet(v bool) { if !v { // Handle delayed data. - e.sndWaker.Assert() + e.sndQueueInfo.sndWaker.Assert() } } @@ -1716,7 +1654,7 @@ func (e *endpoint) OnDelayOptionSet(v bool) { func (e *endpoint) OnCorkOptionSet(v bool) { if !v { // Handle the corked data. - e.sndWaker.Assert() + e.sndQueueInfo.sndWaker.Assert() } } @@ -1724,6 +1662,37 @@ func (e *endpoint) getSendBufferSize() int { return int(e.ops.GetSendBufferSize()) } +// OnSetReceiveBufferSize implements tcpip.SocketOptionsHandler.OnSetReceiveBufferSize. +func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) { + e.LockUser() + e.rcvQueueInfo.rcvQueueMu.Lock() + + // Make sure the receive buffer size allows us to send a + // non-zero window size. + scale := uint8(0) + if e.rcv != nil { + scale = e.rcv.RcvWndScale + } + if rcvBufSz>>scale == 0 { + rcvBufSz = 1 << scale + } + + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(int(oldSz))) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(int(rcvBufSz))) + e.rcvQueueInfo.RcvAutoParams.Disabled = true + + // Immediately send an ACK to uncork the sender silly window + // syndrome prevetion, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, int(rcvBufSz)); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + e.rcvQueueInfo.rcvQueueMu.Unlock() + e.UnlockUser() + return rcvBufSz +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. 
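The new OnSetReceiveBufferSize handler and the reworked windowCrossedACKThresholdLocked both receive the buffer size explicitly now that it lives behind e.ops. A simplified sketch of the two checks they perform; the real code does the threshold comparison in window units via wndFromSpace, which is omitted here:

package tcpsketch

// clampToWindowScale mirrors the check in OnSetReceiveBufferSize: after
// applying the negotiated window-scale shift, the buffer must still map to a
// non-zero window, otherwise the peer would be told it can never send.
func clampToWindowScale(rcvBufSz int64, scale uint8) int64 {
	if rcvBufSz>>scale == 0 {
		rcvBufSz = 1 << scale
	}
	return rcvBufSz
}

// windowCrossedACKThreshold reports whether the advertisable window moved
// across the notification threshold min(aMSS, rcvBufSize/2) when it changed
// by delta; crossing upward is when the diff wakes the protocol goroutine to
// send an ACK and avoid receiver-side silly window syndrome.
func windowCrossedACKThreshold(newAvail, delta, amss, rcvBufSize int) (crossed, above bool) {
	oldAvail := newAvail - delta
	if oldAvail < 0 {
		oldAvail = 0
	}
	threshold := amss
	if half := rcvBufSize / 2; half < threshold {
		threshold = half
	}
	switch {
	case oldAvail < threshold && newAvail >= threshold:
		return true, true
	case oldAvail >= threshold && newAvail < threshold:
		return true, false
	}
	return false, false
}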
RFC 3168, section 23.1 @@ -1767,56 +1736,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return &tcpip.ErrNotSupported{} } - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs tcpip.TCPReceiveBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) - } - - if v > rs.Max { - v = rs.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < rs.Min { - v = rs.Min - } - } else { - v = math.MaxInt32 - } - - e.LockUser() - e.rcvListMu.Lock() - - // Make sure the receive buffer size allows us to send a - // non-zero window size. - scale := uint8(0) - if e.rcv != nil { - scale = e.rcv.rcvWndScale - } - if v>>scale == 0 { - v = 1 << scale - } - - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvBufSize = v - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - - e.rcvAutoParams.disabled = true - - // Immediately send an ACK to uncork the sender silly window - // syndrome prevetion, when our available space grows above aMSS - // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) - } - - e.rcvListMu.Unlock() - e.UnlockUser() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1959,10 +1878,10 @@ func (e *endpoint) readyReceiveSize() (int, tcpip.Error) { return 0, &tcpip.ErrInvalidEndpointState{} } - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() - return e.rcvBufUsed, nil + return e.rcvQueueInfo.RcvBufUsed, nil } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -2002,12 +1921,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - v := e.rcvBufSize - e.rcvListMu.Unlock() - return v, nil - case tcpip.TTLOption: e.LockUser() v := int(e.ttl) @@ -2043,15 +1956,15 @@ func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { // the connection did not send and receive data, then RTT will // be zero. snd.rtt.Lock() - info.RTT = snd.rtt.srtt - info.RTTVar = snd.rtt.rttvar + info.RTT = snd.rtt.TCPRTTState.SRTT + info.RTTVar = snd.rtt.TCPRTTState.RTTVar snd.rtt.Unlock() - info.RTO = snd.rto + info.RTO = snd.RTO info.CcState = snd.state - info.SndSsthresh = uint32(snd.sndSsthresh) - info.SndCwnd = uint32(snd.sndCwnd) - info.ReorderSeen = snd.rc.reorderSeen + info.SndSsthresh = uint32(snd.Ssthresh) + info.SndCwnd = uint32(snd.SndCwnd) + info.ReorderSeen = snd.rc.Reord } e.UnlockUser() return info @@ -2096,7 +2009,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) e.UnlockUser() if err != nil { return err @@ -2204,20 +2117,20 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp } // Find a route to the desired destination. 
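The deleted tcpip.ReceiveBufferSizeOption branch above used to clamp the requested size inline; after this change the limits come from GetTCPReceiveBufferLimits (added at the bottom of this file's hunk) and the clamping presumably happens in the shared socket-options plumbing. A sketch of the clamping that was removed, kept here for reference:

package tcpsketch

import "math"

// receiveBufferRange mirrors the Min/Default/Max triple that
// GetTCPReceiveBufferLimits hands to the generic socket-options code.
type receiveBufferRange struct {
	Min, Default, Max int
}

// clampReceiveBufferSize sketches the removed SetSockOptInt logic: scale the
// request by the segment overhead factor, guard against overflow, and pin the
// result to the protocol's configured range.
func clampReceiveBufferSize(v int, rs receiveBufferRange) int {
	const segOverheadFactor = 8 // assumed value, see the earlier sketch
	if v > rs.Max {
		v = rs.Max
	}
	if v < math.MaxInt32/segOverheadFactor {
		v *= segOverheadFactor
		if v < rs.Min {
			v = rs.Min
		}
	} else {
		v = math.MaxInt32
	}
	return v
}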
- r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.TransportEndpointInfo.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.ID.LocalAddress = r.LocalAddress() - e.ID.RemoteAddress = r.RemoteAddress() - e.ID.RemotePort = addr.Port + e.TransportEndpointInfo.ID.LocalAddress = r.LocalAddress() + e.TransportEndpointInfo.ID.RemoteAddress = r.RemoteAddress() + e.TransportEndpointInfo.ID.RemotePort = addr.Port - if e.ID.LocalPort != 0 { + if e.TransportEndpointInfo.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -2226,19 +2139,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + sameAddr := e.TransportEndpointInfo.ID.LocalAddress == e.TransportEndpointInfo.ID.RemoteAddress // Calculate a port offset based on the destination IP/port and // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.Seed()) - h.Write([]byte(e.ID.LocalAddress)) - h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) - h.Write(portBuf) - portOffset := uint16(h.Sum32()) + + h := jenkins.Sum32(e.stack.Seed()) + for _, s := range [][]byte{ + []byte(e.ID.LocalAddress), + []byte(e.ID.RemoteAddress), + portBuf, + } { + // Per io.Writer.Write: + // + // Write must return a non-nil error if it returns n < len(p). 
+ if _, err := h.Write(s); err != nil { + panic(err) + } + } + portOffset := h.Sum32() var twReuse tcpip.TCPTimeWaitReuseOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { @@ -2249,21 +2172,21 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { switch netProto { case header.IPv4ProtocolNumber: - reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + reuse = header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.LocalAddress) && header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.RemoteAddress) case header.IPv6ProtocolNumber: - reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + reuse = e.TransportEndpointInfo.ID.LocalAddress == header.IPv6Loopback && e.TransportEndpointInfo.ID.RemoteAddress == header.IPv6Loopback } } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) { - if sameAddr && p == e.ID.RemotePort { + if sameAddr && p == e.TransportEndpointInfo.ID.RemotePort { return false, nil } portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2273,7 +2196,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } - transEPID := e.ID + transEPID := e.TransportEndpointInfo.ID transEPID.LocalPort = p // Check if an endpoint is registered with demuxer in TIME-WAIT and if // we can reuse it. If we can't find a transport endpoint then we just @@ -2310,7 +2233,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2321,13 +2244,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp } } - id := e.ID + id := e.TransportEndpointInfo.ID id.LocalPort = p if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2342,13 +2265,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // Port picking successful. Save the details of // the selected port. - e.ID = id + e.TransportEndpointInfo.ID = id e.isPortReserved = true e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags e.boundDest = addr return true, nil }); err != nil { + e.stack.Stats().TCP.FailedPortReservations.Increment() return err } } @@ -2367,10 +2291,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // connection setting here. 
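The connect() hunk keeps deriving a stable starting offset for the ephemeral port search by hashing the connection tuple with the stack's seed, but now loops over the inputs and checks the hash writer's error return instead of ignoring it. A stand-alone sketch of the same idea, using hash/fnv purely as a stand-in for the jenkins package (and feeding the seed in as bytes rather than as the hash's initial state):

package tcpsketch

import (
	"encoding/binary"
	"hash/fnv"
)

// portOffset derives a stable starting point for the ephemeral port search
// from the connection 3-tuple and a per-stack seed, so repeated connects to
// the same destination walk the port space in the same order.
func portOffset(seed uint32, localAddr, remoteAddr []byte, remotePort uint16) (uint32, error) {
	portBuf := make([]byte, 2)
	binary.LittleEndian.PutUint16(portBuf, remotePort)

	var seedBuf [4]byte
	binary.LittleEndian.PutUint32(seedBuf[:], seed)

	h := fnv.New32a()
	for _, s := range [][]byte{seedBuf[:], localAddr, remoteAddr, portBuf} {
		// Per io.Writer, Write returns a non-nil error if it wrote
		// fewer than len(s) bytes.
		if _, err := h.Write(s); err != nil {
			return 0, err
		}
	}
	return h.Sum32(), nil
}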
if !handshake { e.segmentQueue.mu.Lock() - for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { + for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.ID - e.sndWaker.Assert() + s.id = e.TransportEndpointInfo.ID + e.sndQueueInfo.sndWaker.Assert() } } e.segmentQueue.mu.Unlock() @@ -2412,10 +2336,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Close for read. if e.shutdownFlags&tcpip.ShutdownRead != 0 { // Mark read side as closed. - e.rcvListMu.Lock() - e.rcvClosed = true - rcvBufUsed := e.rcvBufUsed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = true + rcvBufUsed := e.rcvQueueInfo.RcvBufUsed + e.rcvQueueInfo.rcvQueueMu.Unlock() // If we're fully closed and we have unread data we need to abort // the connection with a RST. @@ -2429,10 +2353,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Close for write. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - e.sndBufMu.Lock() - if e.sndClosed { + e.sndQueueInfo.sndQueueMu.Lock() + if e.sndQueueInfo.SndClosed { // Already closed. - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() if e.EndpointState() == StateTimeWait { return &tcpip.ErrNotConnected{} } @@ -2440,12 +2364,12 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { } // Queue fin segment. - s := newOutgoingSegment(e.ID, nil) - e.sndQueue.PushBack(s) - e.sndBufInQueue++ + s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil) + e.sndQueueInfo.sndQueue.PushBack(s) + e.sndQueueInfo.SndBufInQueue++ // Mark endpoint as closed. - e.sndClosed = true - e.sndBufMu.Unlock() + e.sndQueueInfo.SndClosed = true + e.sndQueueInfo.sndQueueMu.Unlock() e.handleClose() } @@ -2458,9 +2382,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // // By not removing this endpoint from the demuxer mapping, we // ensure that any other bind to the same port fails, as on Linux. - e.rcvListMu.Lock() - e.rcvClosed = true - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = true + e.rcvQueueInfo.rcvQueueMu.Unlock() e.closePendingAcceptableConnectionsLocked() // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.ReadableEvents | waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) @@ -2474,6 +2398,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. func (e *endpoint) Listen(backlog int) tcpip.Error { + // Accept one more than the configured listen backlog to keep in parity with + // Linux. Ref, because of missing equality check here: + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937 + backlog++ err := e.listen(backlog) if err != nil { if !err.IgnoreStats() { @@ -2491,28 +2419,20 @@ func (e *endpoint) listen(backlog int) tcpip.Error { if e.EndpointState() == StateListen && !e.closed { e.acceptMu.Lock() defer e.acceptMu.Unlock() - if e.acceptedChan == nil { + if e.accepted == (accepted{}) { // listen is called after shutdown. 
- e.acceptedChan = make(chan *endpoint, backlog) + e.accepted.cap = backlog e.shutdownFlags = 0 - e.rcvListMu.Lock() - e.rcvClosed = false - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = false + e.rcvQueueInfo.rcvQueueMu.Unlock() } else { - // Adjust the size of the channel iff we can fix + // Adjust the size of the backlog iff we can fit // existing pending connections into the new one. - if len(e.acceptedChan) > backlog { + if e.accepted.endpoints.Len() > backlog { return &tcpip.ErrInvalidEndpointState{} } - if cap(e.acceptedChan) == backlog { - return nil - } - origChan := e.acceptedChan - e.acceptedChan = make(chan *endpoint, backlog) - close(origChan) - for ep := range origChan { - e.acceptedChan <- ep - } + e.accepted.cap = backlog } // Notify any blocked goroutines that they can attempt to @@ -2538,19 +2458,19 @@ func (e *endpoint) listen(backlog int) tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) - // The channel may be non-nil when we're restoring the endpoint, and it + // The queue may be non-zero when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() - if e.acceptedChan == nil { - e.acceptedChan = make(chan *endpoint, backlog) + if e.accepted == (accepted{}) { + e.accepted.cap = backlog } e.acceptMu.Unlock() @@ -2578,24 +2498,25 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - rcvClosed := e.rcvClosed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := e.rcvQueueInfo.RcvClosed + e.rcvQueueInfo.rcvQueueMu.Unlock() // Endpoint must be in listen state before it can accept connections. if rcvClosed || e.EndpointState() != StateListen { return nil, nil, &tcpip.ErrInvalidEndpointState{} } // Get the new accepted endpoint. 
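Listen above now admits one more pending connection than the caller asked for, matching Linux's off-by-one (its queue-full test uses ">" rather than ">="), and listen() resizes the queue in place instead of reallocating a channel, refusing only to shrink below the connections already queued. A condensed sketch of that capacity handling, with invented names:

package tcpsketch

import (
	"container/list"
	"errors"
)

var errInvalidEndpointState = errors.New("invalid endpoint state")

type listenQueue struct {
	endpoints *list.List
	cap       int
}

// setBacklog mirrors the combined Listen/listen behaviour: the effective
// backlog is one larger than requested, and shrinking below the current
// number of pending connections is rejected rather than dropping them.
func (q *listenQueue) setBacklog(requested int) error {
	effective := requested + 1
	if q.endpoints.Len() > effective {
		return errInvalidEndpointState
	}
	q.cap = effective
	return nil
}

// full is the check a listener would use before queueing a new connection.
func (q *listenQueue) full() bool { return q.endpoints.Len() >= q.cap }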
- e.acceptMu.Lock() - defer e.acceptMu.Unlock() var n *endpoint - select { - case n = <-e.acceptedChan: - e.acceptCond.Signal() - default: + e.acceptMu.Lock() + if element := e.accepted.endpoints.Front(); element != nil { + n = e.accepted.endpoints.Remove(element).(*endpoint) + } + e.acceptMu.Unlock() + if n == nil { return nil, nil, &tcpip.ErrWouldBlock{} } + e.acceptCond.Signal() if peerAddr != nil { *peerAddr = n.getRemoteAddress() } @@ -2645,7 +2566,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { if nic == 0 { return &tcpip.ErrBadLocalAddress{} } - e.ID.LocalAddress = addr.Addr + e.TransportEndpointInfo.ID.LocalAddress = addr.Addr } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) @@ -2659,7 +2580,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { Dest: tcpip.FullAddress{}, } port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { - id := e.ID + id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a // listening endpoint bound with the same id and portFlags and bindToDevice @@ -2675,6 +2596,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { return true, nil }) if err != nil { + e.stack.Stats().TCP.FailedPortReservations.Increment() return err } @@ -2684,7 +2606,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { e.boundNICID = nic e.isPortReserved = true e.effectiveNetProtos = netProtos - e.ID.LocalPort = port + e.TransportEndpointInfo.ID.LocalPort = port // Mark endpoint as bound. e.setEndpointState(StateBound) @@ -2698,8 +2620,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { defer e.UnlockUser() return tcpip.FullAddress{ - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -2718,8 +2640,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, NIC: e.boundNICID, } } @@ -2758,13 +2680,13 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p Payload: pkt.Data().AsRange().ToOwnedView(), Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -2777,12 +2699,12 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p // HandleError implements stack.TransportEndpoint. 
func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { handlePacketTooBig := func(mtu uint32) { - e.sndBufMu.Lock() - e.packetTooBigCount++ - if v := int(mtu); v < e.sndMTU { - e.sndMTU = v + e.sndQueueInfo.sndQueueMu.Lock() + e.sndQueueInfo.PacketTooBigCount++ + if v := int(mtu); v < e.sndQueueInfo.SndMTU { + e.sndQueueInfo.SndMTU = v } - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) } @@ -2801,14 +2723,14 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { sendBufferSize := e.getSendBufferSize() - e.sndBufMu.Lock() - notify := e.sndBufUsed >= sendBufferSize>>1 - e.sndBufUsed -= v + e.sndQueueInfo.sndQueueMu.Lock() + notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1 + e.sndQueueInfo.SndBufUsed -= v // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 - e.sndBufMu.Unlock() + notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + e.sndQueueInfo.sndQueueMu.Unlock() if notify { e.waiterQueue.Notify(waiter.WritableEvents) @@ -2819,58 +2741,50 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() if s != nil { - e.rcvBufUsed += s.payloadSize() + e.rcvQueueInfo.RcvBufUsed += s.payloadSize() s.incRef() - e.rcvList.PushBack(s) + e.rcvQueueInfo.rcvQueue.PushBack(s) } else { - e.rcvClosed = true + e.rcvQueueInfo.RcvClosed = true } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() e.waiterQueue.Notify(waiter.ReadableEvents) } // receiveBufferAvailableLocked calculates how many bytes are still available // in the receive buffer. -// rcvListMu must be held when this function is called. -func (e *endpoint) receiveBufferAvailableLocked() int { +// rcvQueueMu must be held when this function is called. +func (e *endpoint) receiveBufferAvailableLocked(rcvBufSize int) int { // We may use more bytes than the buffer size when the receive buffer // shrinks. memUsed := e.receiveMemUsed() - if memUsed >= e.rcvBufSize { + if memUsed >= rcvBufSize { return 0 } - return e.rcvBufSize - memUsed + return rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the // receive buffer based on the actual memory used by all segments held in // receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { - e.rcvListMu.Lock() - available := e.receiveBufferAvailableLocked() - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + available := e.receiveBufferAvailableLocked(int(e.ops.GetReceiveBufferSize())) + e.rcvQueueInfo.rcvQueueMu.Unlock() return available } // receiveBufferUsed returns the amount of in-use receive buffer. func (e *endpoint) receiveBufferUsed() int { - e.rcvListMu.Lock() - used := e.rcvBufUsed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + used := e.rcvQueueInfo.RcvBufUsed + e.rcvQueueInfo.rcvQueueMu.Unlock() return used } -// receiveBufferSize returns the current size of the receive buffer. 
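updateSndBufferUsage above keeps its hysteresis while moving the fields into sndQueueInfo: blocked writers are woken only when at least half the send buffer is free again, so a run of small ACKs does not wake a writer just to queue a segment or two. The check, in isolation:

package tcpsketch

// sndBufWakeup mirrors updateSndBufferUsage: only report that writers should
// be notified when the buffer had been at least half full and the newly acked
// bytes bring usage back below half of the buffer.
func sndBufWakeup(sndBufUsed *int, acked, sndBufSize int) (notifyWriters bool) {
	wasAboveHalf := *sndBufUsed >= sndBufSize/2
	*sndBufUsed -= acked
	return wasAboveHalf && *sndBufUsed < sndBufSize/2
}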
-func (e *endpoint) receiveBufferSize() int { - e.rcvListMu.Lock() - size := e.rcvBufSize - e.rcvListMu.Unlock() - return size -} - // receiveMemUsed returns the total memory in use by segments held by this // endpoint. func (e *endpoint) receiveMemUsed() int { @@ -2899,11 +2813,11 @@ func (e *endpoint) maxReceiveBufferSize() int { // receiveBuffer otherwise we use the max permissible receive buffer size to // compute the scale. func (e *endpoint) rcvWndScaleForHandshake() int { - bufSizeForScale := e.receiveBufferSize() + bufSizeForScale := e.ops.GetReceiveBufferSize() - e.rcvListMu.Lock() - autoTuningDisabled := e.rcvAutoParams.disabled - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled + e.rcvQueueInfo.rcvQueueMu.Unlock() if autoTuningDisabled { return FindWndScale(seqnum.Size(bufSizeForScale)) } @@ -2914,7 +2828,7 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + if e.SendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { e.setRecentTimestamp(tsVal) } } @@ -2924,7 +2838,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, // initializes the recentTS with the value provided in synOpts.TSval. func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { - e.sendTSOk = true + e.SendTSOk = true e.setRecentTimestamp(synOpts.TSVal) } } @@ -2932,7 +2846,7 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(time.Now(), e.tsOffset) + return tcpTimeStamp(time.Now(), e.TSOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is @@ -2971,7 +2885,7 @@ func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { return } if bool(v) && synOpts.SACKPermitted { - e.sackPermitted = true + e.SACKPermitted = true } } @@ -2985,118 +2899,46 @@ func (e *endpoint) maxOptionSize() (size int) { return size } -// completeState makes a full copy of the endpoint and returns it. This is used -// before invoking the probe. The state returned may not be fully consistent if -// there are intervening syscalls when the state is being copied. -func (e *endpoint) completeState() stack.TCPEndpointState { - var s stack.TCPEndpointState - s.SegTime = time.Now() - - // Copy EndpointID. - s.ID = stack.TCPEndpointID(e.ID) - - // Copy endpoint rcv state. 
- e.rcvListMu.Lock() - s.RcvBufSize = e.rcvBufSize - s.RcvBufUsed = e.rcvBufUsed - s.RcvClosed = e.rcvClosed - s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime - s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied - s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied - s.RcvAutoParams.RTT = e.rcvAutoParams.rtt - s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber - s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime - s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled - e.rcvListMu.Unlock() - - // Endpoint TCP Option state. - s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTimestamp() - s.TSOffset = e.tsOffset - s.SACKPermitted = e.sackPermitted +// completeStateLocked makes a full copy of the endpoint and returns it. This is +// used before invoking the probe. +// +// Precondition: e.mu must be held. +func (e *endpoint) completeStateLocked() stack.TCPEndpointState { + s := stack.TCPEndpointState{ + TCPEndpointStateInner: e.TCPEndpointStateInner, + ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID), + SegTime: time.Now(), + Receiver: e.rcv.TCPReceiverState, + Sender: e.snd.TCPSenderState, + } + + sndBufSize := e.getSendBufferSize() + // Copy the send buffer atomically. + e.sndQueueInfo.sndQueueMu.Lock() + s.SndBufState = e.sndQueueInfo.TCPSndBufState + s.SndBufState.SndBufSize = sndBufSize + e.sndQueueInfo.sndQueueMu.Unlock() + + // Copy the receive buffer atomically. + e.rcvQueueInfo.rcvQueueMu.Lock() + s.RcvBufState = e.rcvQueueInfo.TCPRcvBufState + e.rcvQueueInfo.rcvQueueMu.Unlock() + + // Copy the endpoint TCP Option state. s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks]) s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() - // Copy endpoint send state. - sndBufSize := e.getSendBufferSize() - e.sndBufMu.Lock() - s.SndBufSize = sndBufSize - s.SndBufUsed = e.sndBufUsed - s.SndClosed = e.sndClosed - s.SndBufInQueue = e.sndBufInQueue - s.PacketTooBigCount = e.packetTooBigCount - s.SndMTU = e.sndMTU - e.sndBufMu.Unlock() - - // Copy receiver state. - s.Receiver = stack.TCPReceiverState{ - RcvNxt: e.rcv.rcvNxt, - RcvAcc: e.rcv.rcvAcc, - RcvWndScale: e.rcv.rcvWndScale, - PendingBufUsed: e.rcv.pendingBufUsed, - } - - // Copy sender state. 
- s.Sender = stack.TCPSenderState{ - LastSendTime: e.snd.lastSendTime, - DupAckCount: e.snd.dupAckCount, - FastRecovery: stack.TCPFastRecoveryState{ - Active: e.snd.fr.active, - First: e.snd.fr.first, - Last: e.snd.fr.last, - MaxCwnd: e.snd.fr.maxCwnd, - HighRxt: e.snd.fr.highRxt, - RescueRxt: e.snd.fr.rescueRxt, - }, - SndCwnd: e.snd.sndCwnd, - Ssthresh: e.snd.sndSsthresh, - SndCAAckCount: e.snd.sndCAAckCount, - Outstanding: e.snd.outstanding, - SackedOut: e.snd.sackedOut, - SndWnd: e.snd.sndWnd, - SndUna: e.snd.sndUna, - SndNxt: e.snd.sndNxt, - RTTMeasureSeqNum: e.snd.rttMeasureSeqNum, - RTTMeasureTime: e.snd.rttMeasureTime, - Closed: e.snd.closed, - RTO: e.snd.rto, - MaxPayloadSize: e.snd.maxPayloadSize, - SndWndScale: e.snd.sndWndScale, - MaxSentAck: e.snd.maxSentAck, - } e.snd.rtt.Lock() - s.Sender.SRTT = e.snd.rtt.srtt - s.Sender.SRTTInited = e.snd.rtt.srttInited + s.Sender.RTTState = e.snd.rtt.TCPRTTState e.snd.rtt.Unlock() if cubic, ok := e.snd.cc.(*cubicState); ok { - s.Sender.Cubic = stack.TCPCubicState{ - WMax: cubic.wMax, - WLastMax: cubic.wLastMax, - T: cubic.t, - TimeSinceLastCongestion: time.Since(cubic.t), - C: cubic.c, - K: cubic.k, - Beta: cubic.beta, - WC: cubic.wC, - WEst: cubic.wEst, - } + s.Sender.Cubic = cubic.TCPCubicState + s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T) } - rc := &e.snd.rc - s.Sender.RACKState = stack.TCPRACKState{ - XmitTime: rc.xmitTime, - EndSequence: rc.endSequence, - FACK: rc.fack, - RTT: rc.rtt, - Reord: rc.reorderSeen, - DSACKSeen: rc.dsackSeen, - ReoWnd: rc.reoWnd, - ReoWndIncr: rc.reoWndIncr, - ReoWndPersist: rc.reoWndPersist, - RTTSeq: rc.rttSeq, - } + s.Sender.RACKState = e.snd.rc.TCPRACKState return s } @@ -3200,3 +3042,17 @@ func (e *endpoint) allowOutOfWindowAck() bool { e.lastOutOfWindowAckTime = now return true } + +// GetTCPReceiveBufferLimits is used to get send buffer size limits for TCP. +func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOption { + var ss tcpip.TCPReceiveBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.ReceiveBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index a53d76917..6e9777fe4 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -58,7 +58,7 @@ func (e *endpoint) beforeSave() { if !e.route.HasSaveRestoreCapability() { if !e.route.HasDisconncetOkCapability() { panic(&tcpip.ErrSaveRejection{ - Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), + Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort), }) } e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) @@ -67,7 +67,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() } if !e.workerRunning { - // The endpoint must be in acceptedChan or has been just + // The endpoint must be in the accepted queue or has been just // disconnected and closed. 
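completeStateLocked above is the payoff of moving endpoint fields into exported stack.TCP*State structs: because those structs are embedded in the endpoint, a probe snapshot becomes a handful of struct copies taken under the owning mutexes rather than dozens of field-by-field assignments. A toy version of the pattern (struct names invented for illustration):

package tcpsketch

import "sync"

// TCPSndBufStateSketch stands in for an exported introspection struct such as
// stack.TCPSndBufState; embedding it lets a snapshot be a single struct copy.
type TCPSndBufStateSketch struct {
	SndBufUsed    int
	SndBufInQueue int
	SndClosed     bool
}

type sndQueueSketch struct {
	mu sync.Mutex
	TCPSndBufStateSketch
}

// snapshot copies the whole embedded state group under the owning mutex.
func (q *sndQueueSketch) snapshot() TCPSndBufStateSketch {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.TCPSndBufStateSketch
}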
break } @@ -88,7 +88,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() } if e.workerRunning { - panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID)) } default: panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) @@ -99,37 +99,19 @@ func (e *endpoint) beforeSave() { } } -// saveAcceptedChan is invoked by stateify. -func (e *endpoint) saveAcceptedChan() []*endpoint { - if e.acceptedChan == nil { - return nil - } - acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan)) - for i := 0; i < len(acceptedEndpoints); i++ { - select { - case ep := <-e.acceptedChan: - acceptedEndpoints[i] = ep - default: - panic("endpoint acceptedChan buffer got consumed by background context") - } - } - for i := 0; i < len(acceptedEndpoints); i++ { - select { - case e.acceptedChan <- acceptedEndpoints[i]: - default: - panic("endpoint acceptedChan buffer got populated by background context") - } +// saveEndpoints is invoked by stateify. +func (a *accepted) saveEndpoints() []*endpoint { + acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) + for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { + acceptedEndpoints[i] = e.Value.(*endpoint) } return acceptedEndpoints } -// loadAcceptedChan is invoked by stateify. -func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { - if cap(acceptedEndpoints) > 0 { - e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints)) - for _, ep := range acceptedEndpoints { - e.acceptedChan <- ep - } +// loadEndpoints is invoked by stateify. +func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { + for _, ep := range acceptedEndpoints { + a.endpoints.PushBack(ep) } } @@ -183,7 +165,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { @@ -198,14 +180,14 @@ func (e *endpoint) Resume(s *stack.Stack) { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) + if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) { + panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max)) } } } bind := func() { - addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}) if err != nil { panic("unable to parse BindAddr: " + err.String()) } @@ -231,19 +213,19 @@ func (e *endpoint) Resume(s *stack.Stack) { case epState.connected(): bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.ID.RemoteAddress + e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. 
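The saveAcceptedChan/loadAcceptedChan hooks above, which had to drain and refill a channel without racing the background context, reduce to a straightforward list/slice conversion now that the accept queue is a list. A sketch of the new shape of those stateify hooks, with illustrative types:

package tcpsketch

import "container/list"

type pendingConn struct{ id int }

// saveEndpoints flattens the intrusive accept list into a plain slice for
// serialization, preserving order.
func saveEndpoints(q *list.List) []*pendingConn {
	out := make([]*pendingConn, 0, q.Len())
	for e := q.Front(); e != nil; e = e.Next() {
		out = append(out, e.Value.(*pendingConn))
	}
	return out
}

// loadEndpoints rebuilds the list by pushing each saved element back in order.
func loadEndpoints(q *list.List, saved []*pendingConn) {
	for _, c := range saved {
		q.PushBack(c)
	}
}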
If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning) + err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning) if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } @@ -263,7 +245,7 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() - backlog := cap(e.acceptedChan) + backlog := e.accepted.cap if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } @@ -281,7 +263,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}) + err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}) if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } @@ -328,23 +310,3 @@ func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) } - -// saveMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.measureTime = time.Unix(unix.second, unix.nano) -} - -// saveRttMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime { - return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()} -} - -// loadRttMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) { - r.rttMeasureTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a4667906..fe0d7f10f 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -75,63 +75,6 @@ const ( ccCubic = "cubic" ) -// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The -// value is protected by a mutex so that we can increment only when it's -// guaranteed not to go above a threshold. -type synRcvdCounter struct { - sync.Mutex - value uint64 - pending sync.WaitGroup - threshold uint64 -} - -// inc tries to increment the global number of endpoints in SYN-RCVD state. It -// succeeds if the increment doesn't make the count go beyond the threshold, and -// fails otherwise. 
-func (s *synRcvdCounter) inc() bool { - s.Lock() - defer s.Unlock() - if s.value >= s.threshold { - return false - } - - s.pending.Add(1) - s.value++ - - return true -} - -// dec atomically decrements the global number of endpoints in SYN-RCVD -// state. It must only be called if a previous call to inc succeeded. -func (s *synRcvdCounter) dec() { - s.Lock() - defer s.Unlock() - s.value-- - s.pending.Done() -} - -// synCookiesInUse returns true if the synRcvdCount is greater than -// SynRcvdCountThreshold. -func (s *synRcvdCounter) synCookiesInUse() bool { - s.Lock() - defer s.Unlock() - return s.value >= s.threshold -} - -// SetThreshold sets synRcvdCounter.Threshold to ths new threshold. -func (s *synRcvdCounter) SetThreshold(threshold uint64) { - s.Lock() - defer s.Unlock() - s.threshold = threshold -} - -// Threshold returns the current value of synRcvdCounter.Threhsold. -func (s *synRcvdCounter) Threshold() uint64 { - s.Lock() - defer s.Unlock() - return s.threshold -} - type protocol struct { stack *stack.Stack @@ -139,6 +82,7 @@ type protocol struct { sackEnabled bool recovery tcpip.TCPRecovery delayEnabled bool + alwaysUseSynCookies bool sendBufferSize tcpip.TCPSendBufferSizeRangeOption recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string @@ -150,7 +94,6 @@ type protocol struct { minRTO time.Duration maxRTO time.Duration maxRetries uint32 - synRcvdCount synRcvdCounter synRetries uint8 dispatcher dispatcher } @@ -373,9 +316,9 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip p.mu.Unlock() return nil - case *tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPAlwaysUseSynCookies: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(*v)) + p.alwaysUseSynCookies = bool(*v) p.mu.Unlock() return nil @@ -480,9 +423,9 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er p.mu.RUnlock() return nil - case *tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPAlwaysUseSynCookies: p.mu.RLock() - *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + *v = tcpip.TCPAlwaysUseSynCookies(p.alwaysUseSynCookies) p.mu.RUnlock() return nil @@ -507,12 +450,6 @@ func (p *protocol) Wait() { p.dispatcher.wait() } -// SynRcvdCounter returns a reference to the synRcvdCount for this protocol -// instance. -func (p *protocol) SynRcvdCounter() *synRcvdCounter { - return &p.synRcvdCount -} - // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { return parse.TCP(pkt) @@ -537,7 +474,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { lingerTimeout: DefaultTCPLingerTimeout, timeWaitTimeout: DefaultTCPTimeWaitTimeout, timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly, - synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 0a0d5f7a1..9e332dcf7 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -46,54 +47,16 @@ const ( // // +stateify savable type rackControl struct { - // dsackSeen indicates if the connection has seen a DSACK. - dsackSeen bool - - // endSequence is the ending TCP sequence number of the most recent - // acknowledged segment. 
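With synRcvdCounter gone, the protocol no longer flips to SYN cookies based on a global count of half-open connections; the new TCPAlwaysUseSynCookies option forces them on explicitly. The sketch below shows the option's apparent intent under the assumption that a listener otherwise falls back to cookies only when it cannot queue more half-open connections; it is not the real decision path:

package tcpsketch

// shouldUseSynCookies is an illustration only: cookies are used when the
// stack-wide option demands them, or (assumed) when the listener's queue of
// half-open connections is already at capacity.
func shouldUseSynCookies(alwaysUseSynCookies bool, pendingConns, backlog int) bool {
	return alwaysUseSynCookies || pendingConns >= backlog
}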
- endSequence seqnum.Value + stack.TCPRACKState // exitedRecovery indicates if the connection is exiting loss recovery. // This flag is set if the sender is leaving the recovery after // receiving an ACK and is reset during updating of reorder window. exitedRecovery bool - // fack is the highest selectively or cumulatively acknowledged - // sequence. - fack seqnum.Value - // minRTT is the estimated minimum RTT of the connection. minRTT time.Duration - // reorderSeen indicates if reordering has been detected on this - // connection. - reorderSeen bool - - // reoWnd is the reordering window time used for recording packet - // transmission times. It is used to defer the moment at which RACK - // marks a packet lost. - reoWnd time.Duration - - // reoWndIncr is the multiplier applied to adjust reorder window. - reoWndIncr uint8 - - // reoWndPersist is the number of loss recoveries before resetting - // reorder window. - reoWndPersist int8 - - // rtt is the RTT of the most recently delivered packet on the - // connection (either cumulatively acknowledged or selectively - // acknowledged) that was not marked invalid as a possible spurious - // retransmission. - rtt time.Duration - - // rttSeq is the SND.NXT when rtt is updated. - rttSeq seqnum.Value - - // xmitTime is the latest transmission timestamp of the most recent - // acknowledged segment. - xmitTime time.Time `state:".(unixTime)"` - // tlpRxtOut indicates whether there is an unacknowledged // TLP retransmission. tlpRxtOut bool @@ -108,8 +71,8 @@ type rackControl struct { // init initializes RACK specific fields. func (rc *rackControl) init(snd *sender, iss seqnum.Value) { - rc.fack = iss - rc.reoWndIncr = 1 + rc.FACK = iss + rc.ReoWndIncr = 1 rc.snd = snd } @@ -117,7 +80,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := time.Now().Sub(seg.xmitTime) - tsOffset := rc.snd.ep.tsOffset + tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: @@ -138,7 +101,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { } } - rc.rtt = rtt + rc.RTT = rtt // The sender can either track a simple global minimum of all RTT // measurements from the connection, or a windowed min-filtered value @@ -152,9 +115,9 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // ending sequence number of the packet which has been acknowledged // most recently. endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { - rc.xmitTime = seg.xmitTime - rc.endSequence = endSeq + if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + rc.XmitTime = seg.xmitTime + rc.EndSequence = endSeq } } @@ -171,18 +134,18 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // is identified. 
func (rc *rackControl) detectReorder(seg *segment) { endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.fack.LessThan(endSeq) { - rc.fack = endSeq + if rc.FACK.LessThan(endSeq) { + rc.FACK = endSeq return } - if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 { - rc.reorderSeen = true + if endSeq.LessThan(rc.FACK) && seg.xmitCount == 1 { + rc.Reord = true } } func (rc *rackControl) setDSACKSeen(dsackSeen bool) { - rc.dsackSeen = dsackSeen + rc.DSACKSeen = dsackSeen } // shouldSchedulePTO dictates whether we should schedule a PTO or not. @@ -191,7 +154,7 @@ func (s *sender) shouldSchedulePTO() bool { // Schedule PTO only if RACK loss detection is enabled. return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 && // The connection supports SACK. - s.ep.sackPermitted && + s.ep.SACKPermitted && // The connection is not in loss recovery. (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) && // The connection has no SACKed sequences in the SACK scoreboard. @@ -203,9 +166,9 @@ func (s *sender) shouldSchedulePTO() bool { func (s *sender) schedulePTO() { pto := time.Second s.rtt.Lock() - if s.rtt.srttInited && s.rtt.srtt > 0 { - pto = s.rtt.srtt * 2 - if s.outstanding == 1 { + if s.rtt.TCPRTTState.SRTTInited && s.rtt.TCPRTTState.SRTT > 0 { + pto = s.rtt.TCPRTTState.SRTT * 2 + if s.Outstanding == 1 { pto += wcDelayedACKTimeout } } @@ -230,10 +193,10 @@ func (s *sender) probeTimerExpired() tcpip.Error { } var dataSent bool - if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.outstanding < s.sndCwnd { - dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd)) + if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.Outstanding < s.SndCwnd { + dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd)) if dataSent { - s.outstanding += s.pCount(s.writeNext, s.maxPayloadSize) + s.Outstanding += s.pCount(s.writeNext, s.MaxPayloadSize) s.writeNext = s.writeNext.Next() } } @@ -255,10 +218,10 @@ func (s *sender) probeTimerExpired() tcpip.Error { } if highestSeqXmit != nil { - dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd)) + dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd)) if dataSent { s.rc.tlpRxtOut = true - s.rc.tlpHighRxt = s.sndNxt + s.rc.tlpHighRxt = s.SndNxt } } } @@ -274,7 +237,7 @@ func (s *sender) probeTimerExpired() tcpip.Error { // and updates TLP state accordingly. // See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3. func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { - if !(s.ep.sackPermitted && s.rc.tlpRxtOut) { + if !(s.ep.SACKPermitted && s.rc.tlpRxtOut) { return } @@ -317,13 +280,13 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { - dsackSeen := rc.dsackSeen + dsackSeen := rc.DSACKSeen snd := rc.snd // React to DSACK once per round trip. 
// If SND.UNA < RACK.rtt_seq: // RACK.dsack = false - if snd.sndUna.LessThan(rc.rttSeq) { + if snd.SndUna.LessThan(rc.RTTSeq) { dsackSeen = false } @@ -333,18 +296,18 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // RACK.rtt_seq = SND.NXT // RACK.reo_wnd_persist = 16 if dsackSeen { - rc.reoWndIncr++ + rc.ReoWndIncr++ dsackSeen = false - rc.rttSeq = snd.sndNxt - rc.reoWndPersist = tcpRACKRecoveryThreshold + rc.RTTSeq = snd.SndNxt + rc.ReoWndPersist = tcpRACKRecoveryThreshold } else if rc.exitedRecovery { // Else if exiting loss recovery: // RACK.reo_wnd_persist -= 1 // If RACK.reo_wnd_persist <= 0: // RACK.reo_wnd_incr = 1 - rc.reoWndPersist-- - if rc.reoWndPersist <= 0 { - rc.reoWndIncr = 1 + rc.ReoWndPersist-- + if rc.ReoWndPersist <= 0 { + rc.ReoWndIncr = 1 } rc.exitedRecovery = false } @@ -358,14 +321,14 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // Else if RACK.pkts_sacked >= RACK.dupthresh: // RACK.reo_wnd = 0 // return - if !rc.reorderSeen { + if !rc.Reord { if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery { - rc.reoWnd = 0 + rc.ReoWnd = 0 return } - if snd.sackedOut >= nDupAckThreshold { - rc.reoWnd = 0 + if snd.SackedOut >= nDupAckThreshold { + rc.ReoWnd = 0 return } } @@ -374,11 +337,11 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr // RACK.reo_wnd = min(RACK.reo_wnd, SRTT) snd.rtt.Lock() - srtt := snd.rtt.srtt + srtt := snd.rtt.TCPRTTState.SRTT snd.rtt.Unlock() - rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr)) - if srtt < rc.reoWnd { - rc.reoWnd = srtt + rc.ReoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.ReoWndIncr)) + if srtt < rc.ReoWnd { + rc.ReoWnd = srtt } } @@ -403,8 +366,8 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int { } endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if seg.xmitTime.Before(rc.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { - timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.rtt + rc.reoWnd + if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd if timeRemaining <= 0 { seg.lost = true numLost++ @@ -435,7 +398,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error { } fastRetransmit := false - if !rc.snd.fr.active { + if !rc.snd.FastRecovery.Active { rc.snd.cc.HandleLossDetected() rc.snd.enterRecovery() fastRetransmit = true @@ -471,15 +434,15 @@ func (rc *rackControl) DoRecovery(_ *segment, fastRetransmit bool) { } // Check the congestion window after entering recovery. - if snd.outstanding >= snd.sndCwnd { + if snd.Outstanding >= snd.SndCwnd { break } - if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.sndUna.Add(snd.sndWnd)); !sent { + if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.SndUna.Add(snd.SndWnd)); !sent { break } dataSent = true - snd.outstanding += snd.pCount(seg, snd.maxPayloadSize) + snd.Outstanding += snd.pCount(seg, snd.MaxPayloadSize) } snd.postXmit(dataSent, true /* shouldScheduleProbe */) diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go deleted file mode 100644 index c9dc7e773..000000000 --- a/pkg/tcpip/transport/tcp/rack_state.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2020 The gVisor Authors. 
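detectLoss above keeps the RACK timing rule from the draft: an unacknowledged segment that was sent no later than the most recently delivered one is marked lost once its transmission time plus the RACK RTT estimate and the reordering window lies in the past. The core comparison, extracted into a small helper:

package tcpsketch

import "time"

// rackLost mirrors the timing check in detectLoss: given the segment's
// transmission time, the arrival time of the ACK being processed, the RACK
// RTT estimate and the reordering window, the segment is declared lost when
// no time remains before its deadline.
func rackLost(xmitTime, ackRcvTime time.Time, rtt, reoWnd time.Duration) bool {
	timeRemaining := xmitTime.Sub(ackRcvTime) + rtt + reoWnd
	return timeRemaining <= 0
}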
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// saveXmitTime is invoked by stateify. -func (rc *rackControl) saveXmitTime() unixTime { - return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (rc *rackControl) loadXmitTime(unix unixTime) { - rc.xmitTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index bc6793fc6..ee2c08cd6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // receiver holds the state necessary to receive TCP segments and turn them @@ -29,26 +30,15 @@ import ( // // +stateify savable type receiver struct { + stack.TCPReceiverState ep *endpoint - rcvNxt seqnum.Value - - // rcvAcc is one beyond the last acceptable sequence number. That is, - // the "largest" sequence value that the receiver has announced to the - // its peer that it's willing to accept. This may be different than - // rcvNxt + rcvWnd if the receive window is reduced; in that case we - // have to reduce the window as we receive more data instead of - // shrinking it. - rcvAcc seqnum.Value - // rcvWnd is the non-scaled receive window last advertised to the peer. rcvWnd seqnum.Size - // rcvWUP is the rcvNxt value at the last window update sent. + // rcvWUP is the RcvNxt value at the last window update sent. rcvWUP seqnum.Value - rcvWndScale uint8 - // prevBufused is the snapshot of endpoint rcvBufUsed taken when we // advertise a receive window. prevBufUsed int @@ -58,9 +48,6 @@ type receiver struct { // pendingRcvdSegments is bounded by the receive buffer size of the // endpoint. pendingRcvdSegments segmentHeap - // pendingBufUsed tracks the total number of bytes (including segment - // overhead) currently queued in pendingRcvdSegments. - pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` @@ -68,12 +55,14 @@ type receiver struct { func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), + ep: ep, + TCPReceiverState: stack.TCPReceiverState{ + RcvNxt: irs + 1, + RcvAcc: irs.Add(rcvWnd + 1), + RcvWndScale: rcvWndScale, + }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - rcvWndScale: rcvWndScale, lastRcvdAckTime: time.Now(), } } @@ -84,34 +73,34 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // r.rcvWnd could be much larger than the window size we advertised in our // outgoing packets, we should use what we have advertised for acceptability // test. - scaledWindowSize := r.rcvWnd >> r.rcvWndScale + scaledWindowSize := r.rcvWnd >> r.RcvWndScale if scaledWindowSize > math.MaxUint16 { // This is what we actually put in the Window field. 
scaledWindowSize = math.MaxUint16 } - advertisedWindowSize := scaledWindowSize << r.rcvWndScale - return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) + advertisedWindowSize := scaledWindowSize << r.RcvWndScale + return header.Acceptable(segSeq, segLen, r.RcvNxt, r.RcvNxt.Add(advertisedWindowSize)) } // currentWindow returns the available space in the window that was advertised // last to our peer. func (r *receiver) currentWindow() (curWnd seqnum.Size) { endOfWnd := r.rcvWUP.Add(r.rcvWnd) - if endOfWnd.LessThan(r.rcvNxt) { - // return 0 if r.rcvNxt is past the end of the previously advertised window. + if endOfWnd.LessThan(r.RcvNxt) { + // return 0 if r.RcvNxt is past the end of the previously advertised window. // This can happen because we accept a large segment completely even if // accepting it causes it to partially exceed the advertised window. return 0 } - return r.rcvNxt.Size(endOfWnd) + return r.RcvNxt.Size(endOfWnd) } // getSendParams returns the parameters needed by the sender when building // segments to send. -func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { +func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { newWnd := r.ep.selectWindow() curWnd := r.currentWindow() - unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt)) + unackLen := int(r.ep.snd.MaxSentAck.Size(r.RcvNxt)) bufUsed := r.ep.receiveBufferUsed() // Grow the right edge of the window only for payloads larger than the @@ -139,18 +128,18 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // edge, as we are still advertising a window that we think can be serviced. toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed - // Update rcvAcc only if new window is > previously advertised window. We + // Update RcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we // would end up dropping bytes that might already be in flight. // ==================================================== sequence space. // ^ ^ ^ ^ - // rcvWUP rcvNxt rcvAcc new rcvAcc + // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow { - // If the new window moves the right edge, then update rcvAcc. - r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) + if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + // If the new window moves the right edge, then update RcvAcc. + r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -162,9 +151,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. r.rcvWnd = newWnd - r.rcvWUP = r.rcvNxt + r.rcvWUP = r.RcvNxt r.prevBufUsed = bufUsed - scaledWnd := r.rcvWnd >> r.rcvWndScale + scaledWnd := r.rcvWnd >> r.RcvWndScale if scaledWnd == 0 { // Increment a metric if we are advertising an actual zero window. r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() @@ -177,9 +166,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Ensure that the stashed receive window always reflects what // is being advertised. 
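getSendParams above only moves the advertised right edge (RcvAcc) forward; shrinking it could drop bytes the peer already has in flight. A compact sketch of that monotonic-edge rule, with plain ints standing in for seqnum.Value and the payload-size/buffer-drain precondition (toGrow) omitted:

package main

import "fmt"

// advertise returns the new right edge of the receive window and the window
// to announce. The edge only moves forward: if the freshly computed window
// would end before the edge already promised to the peer, the old edge is
// kept and the announced window is recomputed from it.
func advertise(rcvNxt, rcvAcc, newWnd int) (edge, wnd int) {
	if proposed := rcvNxt + newWnd; proposed > rcvAcc {
		return proposed, newWnd // grow the right edge
	}
	return rcvAcc, rcvAcc - rcvNxt // keep the promise already made
}

func main() {
	// Buffer pressure shrank the computed window to 1000 bytes, but data up
	// to sequence 4000 was already offered to the peer; keep that edge.
	edge, wnd := advertise(1000, 4000, 1000)
	fmt.Println(edge, wnd) // 4000 3000
}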
- r.rcvWnd = scaledWnd << r.rcvWndScale + r.rcvWnd = scaledWnd << r.RcvWndScale } - return r.rcvNxt, scaledWnd + return r.RcvNxt, scaledWnd } // nonZeroWindow is called when the receive window grows from zero to nonzero; @@ -201,13 +190,13 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // If the segment doesn't include the seqnum we're expecting to // consume now, we're missing a segment. We cannot proceed until // we receive that segment though. - if !r.rcvNxt.InWindow(segSeq, segLen) { + if !r.RcvNxt.InWindow(segSeq, segLen) { return false } // Trim segment to eliminate already acknowledged data. - if segSeq.LessThan(r.rcvNxt) { - diff := segSeq.Size(r.rcvNxt) + if segSeq.LessThan(r.RcvNxt) { + diff := segSeq.Size(r.RcvNxt) segLen -= diff segSeq.UpdateForward(diff) s.sequenceNumber.UpdateForward(diff) @@ -217,35 +206,35 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Move segment to ready-to-deliver list. Wakeup any waiters. r.ep.readyToRead(s) - } else if segSeq != r.rcvNxt { + } else if segSeq != r.RcvNxt { return false } // Update the segment that we're expecting to consume. - r.rcvNxt = segSeq.Add(segLen) + r.RcvNxt = segSeq.Add(segLen) // In cases of a misbehaving sender which could send more than the // advertised window, we could end up in a situation where we get a // segment that exceeds the window advertised. Instead of partially // accepting the segment and discarding bytes beyond the advertised - // window, we accept the whole segment and make sure r.rcvAcc is moved - // forward to match r.rcvNxt to indicate that the window is now closed. + // window, we accept the whole segment and make sure r.RcvAcc is moved + // forward to match r.RcvNxt to indicate that the window is now closed. // // In absence of this check the r.acceptable() check fails and accepts // segments that should be dropped because rcvWnd is calculated as - // the size of the interval (rcvNxt, rcvAcc] which becomes extremely - // large if rcvAcc is ever less than rcvNxt. - if r.rcvAcc.LessThan(r.rcvNxt) { - r.rcvAcc = r.rcvNxt + // the size of the interval (RcvNxt, RcvAcc] which becomes extremely + // large if RcvAcc is ever less than RcvNxt. + if r.RcvAcc.LessThan(r.RcvNxt) { + r.RcvAcc = r.RcvNxt } // Trim SACK Blocks to remove any SACK information that covers // sequence numbers that have been consumed. - TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { - r.rcvNxt++ + r.RcvNxt++ // Send ACK immediately. r.ep.snd.sendAck() @@ -260,7 +249,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. 
r.ep.setEndpointState(StateTimeWait) } else { @@ -280,7 +269,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { - r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() + r.PendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() // Note that slice truncation does not allow garbage collection of @@ -295,7 +284,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -323,40 +312,40 @@ func (r *receiver) updateRTT() { // estimate the round-trip time by observing the time between when a byte // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. - r.ep.rcvListMu.Lock() - if r.ep.rcvAutoParams.rttMeasureTime.IsZero() { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { // New measurement. - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) { - r.ep.rcvListMu.Unlock() + if r.RcvNxt.LessThan(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber) { + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime) + rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. - if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt { - r.ep.rcvAutoParams.rtt = rtt + if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { + r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { - r.ep.rcvListMu.Lock() - rcvClosed := r.ep.rcvClosed || r.closed - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := r.ep.rcvQueueInfo.RcvClosed || r.closed + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. 
switch state { case StateCloseWait, StateClosing, StateLastAck: - if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + if !s.sequenceNumber.LessThanEq(r.RcvNxt) { // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -384,17 +373,17 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // The ESTABLISHED state processing is here where if the ACK check // fails, we ignore the packet: // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + if r.ep.snd.SndNxt.LessThan(s.ackNumber) { r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., - // SHUT_RD) then any data past the rcvNxt should + // SHUT_RD) then any data past the RcvNxt should // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + if state != StateCloseWait && rcvClosed && r.RcvNxt.LessThan(endDataSeq) { return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { @@ -403,7 +392,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If it's a retransmission of an old data segment // or a pure ACK then allow it. - if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.RcvNxt) || s.logicalLen() == 0 { break } @@ -413,7 +402,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // then the only acceptable segment is a // FIN. Since FIN can technically also carry // data we verify that the segment carrying a - // FIN ends at exactly e.rcvNxt+1. + // FIN ends at exactly e.RcvNxt+1. // // From RFC793 page 25. // @@ -423,7 +412,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -435,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // end has closed and the peer is yet to send a FIN. Hence we // compare only the payload. segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + if rcvClosed && !segEnd.LessThanEq(r.RcvNxt) { return true, nil } return false, nil @@ -477,13 +466,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. 
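handleRcvdSegment, continued below, admits an out-of-order segment onto pendingRcvdSegments only while pending memory stays under a quarter of the receive buffer, so in-order data always has somewhere to land. A small sketch of that admission test; the parameter names are stand-ins for the endpoint's accounting:

package main

import "fmt"

// admitOutOfOrder reports whether another out-of-order segment may be queued:
// only while the memory already pinned by pending segments is below a quarter
// of the configured receive buffer.
func admitOutOfOrder(pendingBufUsed, rcvBufSize int) bool {
	return rcvBufSize > 0 && pendingBufUsed < rcvBufSize>>2
}

func main() {
	fmt.Println(admitOutOfOrder(10000, 64000)) // true: 10000 < 16000
	fmt.Println(admitOutOfOrder(16000, 64000)) // false: quota exhausted
}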
- if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { - r.ep.rcvListMu.Lock() - r.pendingBufUsed += s.segMemSize() - r.ep.rcvListMu.Unlock() + if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed += s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) - UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) + UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.RcvNxt) } // Immediately send an ack so that the peer knows it may @@ -508,15 +497,15 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { segSeq := s.sequenceNumber // Skip segment altogether if it has already been acknowledged. - if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) && + if !segSeq.Add(segLen-1).LessThan(r.RcvNxt) && !r.consumeSegment(s, segSeq, segLen) { break } heap.Pop(&r.pendingRcvdSegments) - r.ep.rcvListMu.Lock() - r.pendingBufUsed -= s.segMemSize() - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed -= s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.decRef() } return false, nil @@ -558,7 +547,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } @@ -569,11 +558,11 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn } // Update Timestamp if required. See RFC7323, section-4.3. - if r.ep.sendTSOk && s.parsedOptions.TS { - r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + if r.ep.SendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() @@ -584,8 +573,8 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // carries data then just send an ACK. This is according to RFC 793, // page 37. // - // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. - if segSeq != r.rcvNxt || segLen != 0 { + // NOTE: In TIME_WAIT the only acceptable sequence number is RcvNxt. + if segSeq != r.RcvNxt || segLen != 0 { r.ep.snd.sendAck() } return false, false diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go index ff39780a5..063552c7f 100644 --- a/pkg/tcpip/transport/tcp/reno.go +++ b/pkg/tcpip/transport/tcp/reno.go @@ -34,14 +34,14 @@ func newRenoCC(s *sender) *renoState { func (r *renoState) updateSlowStart(packetsAcked int) int { // Don't let the congestion window cross into the congestion // avoidance range. 
- newcwnd := r.s.sndCwnd + packetsAcked - if newcwnd >= r.s.sndSsthresh { - newcwnd = r.s.sndSsthresh - r.s.sndCAAckCount = 0 + newcwnd := r.s.SndCwnd + packetsAcked + if newcwnd >= r.s.Ssthresh { + newcwnd = r.s.Ssthresh + r.s.SndCAAckCount = 0 } - packetsAcked -= newcwnd - r.s.sndCwnd - r.s.sndCwnd = newcwnd + packetsAcked -= newcwnd - r.s.SndCwnd + r.s.SndCwnd = newcwnd return packetsAcked } @@ -49,19 +49,19 @@ func (r *renoState) updateSlowStart(packetsAcked int) int { // avoidance mode as described in RFC5681 section 3.1 func (r *renoState) updateCongestionAvoidance(packetsAcked int) { // Consume the packets in congestion avoidance mode. - r.s.sndCAAckCount += packetsAcked - if r.s.sndCAAckCount >= r.s.sndCwnd { - r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd - r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd + r.s.SndCAAckCount += packetsAcked + if r.s.SndCAAckCount >= r.s.SndCwnd { + r.s.SndCwnd += r.s.SndCAAckCount / r.s.SndCwnd + r.s.SndCAAckCount = r.s.SndCAAckCount % r.s.SndCwnd } } // reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, // page 6, eq. 4. It is called when we detect congestion in the network. func (r *renoState) reduceSlowStartThreshold() { - r.s.sndSsthresh = r.s.outstanding / 2 - if r.s.sndSsthresh < 2 { - r.s.sndSsthresh = 2 + r.s.Ssthresh = r.s.Outstanding / 2 + if r.s.Ssthresh < 2 { + r.s.Ssthresh = 2 } } @@ -70,7 +70,7 @@ func (r *renoState) reduceSlowStartThreshold() { // were acknowledged. // Update implements congestionControl.Update. func (r *renoState) Update(packetsAcked int) { - if r.s.sndCwnd < r.s.sndSsthresh { + if r.s.SndCwnd < r.s.Ssthresh { packetsAcked = r.updateSlowStart(packetsAcked) if packetsAcked == 0 { return @@ -94,7 +94,7 @@ func (r *renoState) HandleRTOExpired() { // Reduce the congestion window to 1, i.e., enter slow-start. Per // RFC 5681, page 7, we must use 1 regardless of the value of the // initial congestion window. - r.s.sndCwnd = 1 + r.s.SndCwnd = 1 } // PostRecovery implements congestionControl.PostRecovery. diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go index 2aa708e97..d368a29fc 100644 --- a/pkg/tcpip/transport/tcp/reno_recovery.go +++ b/pkg/tcpip/transport/tcp/reno_recovery.go @@ -31,25 +31,25 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { snd := rr.s // We are in fast recovery mode. Ignore the ack if it's out of range. - if !ack.InRange(snd.sndUna, snd.sndNxt+1) { + if !ack.InRange(snd.SndUna, snd.SndNxt+1) { return } // Don't count this as a duplicate if it is carrying data or // updating the window. - if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window { + if rcvdSeg.logicalLen() != 0 || snd.SndWnd != rcvdSeg.window { return } // Inflate the congestion window if we're getting duplicate acks // for the packet we retransmitted. - if !fastRetransmit && ack == snd.fr.first { + if !fastRetransmit && ack == snd.FastRecovery.First { // We received a dup, inflate the congestion window by 1 packet // if we're not at the max yet. Only inflate the window if // regular FastRecovery is in use, RFC6675 does not require // inflating cwnd on duplicate ACKs. - if snd.sndCwnd < snd.fr.maxCwnd { - snd.sndCwnd++ + if snd.SndCwnd < snd.FastRecovery.MaxCwnd { + snd.SndCwnd++ } return } @@ -61,7 +61,7 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { // back onto the wire. // // N.B. The retransmit timer will be reset by the caller. 
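The reno.go hunks above rename fields onto the exported state but keep the RFC 5681 growth rules: the window grows by one per ACKed packet in slow start and by roughly one per round trip in congestion avoidance. A self-contained sketch of the same arithmetic on plain ints:

package main

import "fmt"

type reno struct {
	cwnd, ssthresh, caAckCount int
}

// update grows the congestion window for packetsAcked newly acknowledged
// packets: exponential slow start until ssthresh, then congestion avoidance
// where cwnd grows by one only after cwnd packets have been acknowledged.
func (r *reno) update(packetsAcked int) {
	if r.cwnd < r.ssthresh {
		newcwnd := r.cwnd + packetsAcked
		if newcwnd >= r.ssthresh {
			// Cap slow start at the threshold; leftover ACKs fall
			// through to congestion avoidance.
			newcwnd = r.ssthresh
			r.caAckCount = 0
		}
		packetsAcked -= newcwnd - r.cwnd
		r.cwnd = newcwnd
		if packetsAcked == 0 {
			return
		}
	}
	r.caAckCount += packetsAcked
	if r.caAckCount >= r.cwnd {
		r.cwnd += r.caAckCount / r.cwnd
		r.caAckCount %= r.cwnd
	}
}

func main() {
	r := &reno{cwnd: 10, ssthresh: 12}
	r.update(4) // 2 packets finish slow start, 2 count toward avoidance
	fmt.Println(r.cwnd, r.caAckCount) // 12 2
}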
- snd.fr.first = ack - snd.dupAckCount = 0 + snd.FastRecovery.First = ack + snd.DupAckCount = 0 snd.resendSegment() } diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go index 9d406b0bc..cd860b5e8 100644 --- a/pkg/tcpip/transport/tcp/sack_recovery.go +++ b/pkg/tcpip/transport/tcp/sack_recovery.go @@ -42,14 +42,14 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen } nextSegHint := snd.writeList.Front() - for snd.outstanding < snd.sndCwnd { + for snd.Outstanding < snd.SndCwnd { var nextSeg *segment var rescueRtx bool nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint) if nextSeg == nil { return dataSent } - if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + if !snd.isAssignedSequenceNumber(nextSeg) || snd.SndNxt.LessThanEq(nextSeg.sequenceNumber) { // New data being sent. // Step C.3 described below is handled by @@ -67,7 +67,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen return dataSent } dataSent = true - snd.outstanding++ + snd.Outstanding++ snd.writeNext = nextSeg.Next() continue } @@ -79,7 +79,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // "The estimate of the amount of data outstanding in the network // must be updated by incrementing pipe by the number of octets // transmitted in (C.1)." - snd.outstanding++ + snd.Outstanding++ dataSent = true snd.sendSegment(nextSeg) @@ -88,7 +88,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // We do the last part of rule (4) of NextSeg here to update // RescueRxt as until this point we don't know if we are going // to use the rescue transmission. - snd.fr.rescueRxt = snd.fr.last + snd.FastRecovery.RescueRxt = snd.FastRecovery.Last } else { // RFC 6675, Step C.2 // @@ -96,7 +96,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // HighData, HighRxt MUST be set to the highest sequence // number of the retransmitted segment unless NextSeg () // rule (4) was invoked for this retransmission." - snd.fr.highRxt = segEnd - 1 + snd.FastRecovery.HighRxt = segEnd - 1 } } return dataSent @@ -109,12 +109,12 @@ func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { } // We are in fast recovery mode. Ignore the ack if it's out of range. - if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) { + if ack := rcvdSeg.ackNumber; !ack.InRange(snd.SndUna, snd.SndNxt+1) { return } // RFC 6675 recovery algorithm step C 1-5. 
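sackRecovery.DoRecovery in the surrounding hunks drives the RFC 6675 step C loop: while the in-flight estimate is below cwnd, pick the next segment with NextSeg() and charge it to the pipe. A much-simplified sketch of that loop, with boolean flags standing in for the SACK scoreboard and only rules (1) and (2) of NextSeg modelled:

package main

import "fmt"

// seg is a stand-in for a retransmit-queue entry.
type seg struct {
	seq    int
	lost   bool // marked lost by the scoreboard
	sacked bool
	sent   bool
}

// recoverySend keeps transmitting while the in-flight estimate (pipe) is
// below cwnd: lost, unSACKed data is retransmitted first, then previously
// unsent data, each transmission charged against pipe.
func recoverySend(queue []*seg, pipe, cwnd int) (sent []int, newPipe int) {
	for pipe < cwnd {
		var next *seg
		for _, s := range queue {
			if s.lost && !s.sacked && !s.sent { // rule (1): retransmit lost data
				next = s
				break
			}
		}
		if next == nil {
			for _, s := range queue {
				if !s.sent { // rule (2): send new data
					next = s
					break
				}
			}
		}
		if next == nil {
			break
		}
		next.sent = true
		sent = append(sent, next.seq)
		pipe++
	}
	return sent, pipe
}

func main() {
	q := []*seg{{seq: 1, lost: true}, {seq: 2, sacked: true, sent: true}, {seq: 3}}
	sent, pipe := recoverySend(q, 1, 3)
	fmt.Println(sent, pipe) // [1 3] 3
}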
- end := snd.sndUna.Add(snd.sndWnd) - dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end) + end := snd.SndUna.Add(snd.SndWnd) + dataSent := sr.handleSACKRecovery(snd.MaxPayloadSize, end) snd.postXmit(dataSent, true /* shouldScheduleProbe */) } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 8edd6775b..c28641be3 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -236,20 +236,14 @@ func (s *segment) parse(skipChecksumValidation bool) bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - - verifyChecksum := true if skipChecksumValidation { s.csumValid = true - verifyChecksum = false - } - if verifyChecksum { + } else { s.csum = s.hdr.Checksum() - xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) - xsum = s.hdr.CalculateChecksum(xsum) - xsum = header.ChecksumVV(s.data, xsum) - s.csumValid = xsum == 0xffff + payloadChecksum := header.ChecksumVV(s.data, 0) + payloadLength := uint16(s.data.Size()) + s.csumValid = s.hdr.IsChecksumValid(s.srcAddr, s.dstAddr, payloadChecksum, payloadLength) } - s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) s.ackNumber = seqnum.Value(s.hdr.AckNumber()) s.flags = s.hdr.Flags() diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 54545a1b1..d0d1b0b8a 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -52,12 +52,12 @@ func (q *segmentQueue) empty() bool { func (q *segmentQueue) enqueue(s *segment) bool { // q.ep.receiveBufferParams() must be called without holding q.mu to // avoid lock order inversion. - bufSz := q.ep.receiveBufferSize() + bufSz := q.ep.ops.GetReceiveBufferSize() used := q.ep.receiveMemUsed() q.mu.Lock() // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue // is currently full). - allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + allow := (used <= int(bufSz) || s.payloadSize() == 0) && !q.frozen if allow { q.list.PushBack(s) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index faca35892..cf2e8dcd8 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -85,56 +86,12 @@ type lossRecovery interface { // // +stateify savable type sender struct { + stack.TCPSenderState ep *endpoint - // lastSendTime is the timestamp when the last packet was sent. - lastSendTime time.Time `state:".(unixTime)"` - - // dupAckCount is the number of duplicated acks received. It is used for - // fast retransmit. - dupAckCount int - - // fr holds state related to fast recovery. - fr fastRecovery - // lr is the loss recovery algorithm used by the sender. lr lossRecovery - // sndCwnd is the congestion window, in packets. - sndCwnd int - - // sndSsthresh is the threshold between slow start and congestion - // avoidance. - sndSsthresh int - - // sndCAAckCount is the number of packets acknowledged during congestion - // avoidance. When enough packets have been ack'd (typically cwnd - // packets), the congestion window is incremented by one. - sndCAAckCount int - - // outstanding is the number of outstanding packets, that is, packets - // that have been sent but not yet acknowledged. 
- outstanding int - - // sackedOut is the number of packets which are selectively acked. - sackedOut int - - // sndWnd is the send window size. - sndWnd seqnum.Size - - // sndUna is the next unacknowledged sequence number. - sndUna seqnum.Value - - // sndNxt is the sequence number of the next segment to be sent. - sndNxt seqnum.Value - - // rttMeasureSeqNum is the sequence number being used for the latest RTT - // measurement. - rttMeasureSeqNum seqnum.Value - - // rttMeasureTime is the time when the rttMeasureSeqNum was sent. - rttMeasureTime time.Time `state:".(unixTime)"` - // firstRetransmittedSegXmitTime is the original transmit time of // the first segment that was retransmitted due to RTO expiration. firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` @@ -147,17 +104,15 @@ type sender struct { // window probes. unackZeroWindowProbes uint32 `state:"nosave"` - closed bool writeNext *segment writeList segmentList resendTimer timer `state:"nosave"` resendWaker sleep.Waker `state:"nosave"` - // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time", - // "round-trip time variation" and "retransmit timeout", as defined in + // rtt.TCPRTTState.SRTT and rtt.TCPRTTState.RTTVar are the "smoothed + // round-trip time", and "round-trip time variation", as defined in // section 2 of RFC 6298. rtt rtt - rto time.Duration // minRTO is the minimum permitted value for sender.rto. minRTO time.Duration @@ -168,20 +123,9 @@ type sender struct { // maxRetries is the maximum permitted retransmissions. maxRetries uint32 - // maxPayloadSize is the maximum size of the payload of a given segment. - // It is initialized on demand. - maxPayloadSize int - // gso is set if generic segmentation offload is enabled. gso bool - // sndWndScale is the number of bits to shift left when reading the send - // window size from a segment. - sndWndScale uint8 - - // maxSentAck is the maxium acknowledgement actually sent. - maxSentAck seqnum.Value - // state is the current state of congestion control for this endpoint. state tcpip.CongestionControlState @@ -209,41 +153,7 @@ type sender struct { type rtt struct { sync.Mutex `state:"nosave"` - srtt time.Duration - rttvar time.Duration - srttInited bool -} - -// fastRecovery holds information related to fast recovery from a packet loss. -// -// +stateify savable -type fastRecovery struct { - // active whether the endpoint is in fast recovery. The following fields - // are only meaningful when active is true. - active bool - - // first and last represent the inclusive sequence number range being - // recovered. - first seqnum.Value - last seqnum.Value - - // maxCwnd is the maximum value the congestion window may be inflated to - // due to duplicate acks. This exists to avoid attacks where the - // receiver intentionally sends duplicate acks to artificially inflate - // the sender's cwnd. - maxCwnd int - - // highRxt is the highest sequence number which has been retransmitted - // during the current loss recovery phase. - // See: RFC 6675 Section 2 for details. - highRxt seqnum.Value - - // rescueRxt is the highest sequence number which has been - // optimistically retransmitted to prevent stalling of the ACK clock - // when there is loss at the end of the window and no new data is - // available for transmission. - // See: RFC 6675 Section 2 for details. 
- rescueRxt seqnum.Value + stack.TCPRTTState } func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { @@ -253,20 +163,22 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint maxPayloadSize := int(mss) - ep.maxOptionSize() s := &sender{ - ep: ep, - sndWnd: sndWnd, - sndUna: iss + 1, - sndNxt: iss + 1, - rto: 1 * time.Second, - rttMeasureSeqNum: iss + 1, - lastSendTime: time.Now(), - maxPayloadSize: maxPayloadSize, - maxSentAck: irs + 1, - fr: fastRecovery{ - // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. - last: iss, - highRxt: iss, - rescueRxt: iss, + ep: ep, + TCPSenderState: stack.TCPSenderState{ + SndWnd: sndWnd, + SndUna: iss + 1, + SndNxt: iss + 1, + RTTMeasureSeqNum: iss + 1, + LastSendTime: time.Now(), + MaxPayloadSize: maxPayloadSize, + MaxSentAck: irs + 1, + FastRecovery: stack.TCPFastRecoveryState{ + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. + Last: iss, + HighRxt: iss, + RescueRxt: iss, + }, + RTO: 1 * time.Second, }, gso: ep.gso != nil, } @@ -282,7 +194,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. if sndWndScale > 0 { - s.sndWndScale = uint8(sndWndScale) + s.SndWndScale = uint8(sndWndScale) } s.resendTimer.init(&s.resendWaker) @@ -294,7 +206,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // Initialize SACK Scoreboard after updating max payload size as we use // the maxPayloadSize as the smss when determining if a segment is lost // etc. - s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + s.ep.scoreboard = NewSACKScoreboard(uint16(s.MaxPayloadSize), iss) // Get Stack wide config. var minRTO tcpip.TCPMinRTOOption @@ -322,10 +234,10 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // returns a handle to it. It also initializes the sndCwnd and sndSsThresh to // their initial values. func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl { - s.sndCwnd = InitialCwnd + s.SndCwnd = InitialCwnd // Set sndSsthresh to the maximum int value, which depends on the // platform. - s.sndSsthresh = int(^uint(0) >> 1) + s.Ssthresh = int(^uint(0) >> 1) switch congestionControlName { case ccCubic: @@ -339,7 +251,7 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon // initLossRecovery initiates the loss recovery algorithm for the sender. func (s *sender) initLossRecovery() lossRecovery { - if s.ep.sackPermitted { + if s.ep.SACKPermitted { return newSACKRecovery(s) } return newRenoRecovery(s) @@ -355,7 +267,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m -= s.ep.maxOptionSize() // We don't adjust up for now. - if m >= s.maxPayloadSize { + if m >= s.MaxPayloadSize { return } @@ -364,8 +276,8 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } - oldMSS := s.maxPayloadSize - s.maxPayloadSize = m + oldMSS := s.MaxPayloadSize + s.MaxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) } @@ -380,9 +292,9 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // maxPayloadSize. s.ep.scoreboard.smss = uint16(m) - s.outstanding -= count - if s.outstanding < 0 { - s.outstanding = 0 + s.Outstanding -= count + if s.Outstanding < 0 { + s.Outstanding = 0 } // Rewind writeNext to the first segment exceeding the MTU. 
Do nothing @@ -401,10 +313,10 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { nextSeg = seg } - if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + if s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Update sackedOut for new maximum payload size. - s.sackedOut -= s.pCount(seg, oldMSS) - s.sackedOut += s.pCount(seg, s.maxPayloadSize) + s.SackedOut -= s.pCount(seg, oldMSS) + s.SackedOut += s.pCount(seg, s.MaxPayloadSize) } } @@ -416,32 +328,32 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // sendAck sends an ACK segment. func (s *sender) sendAck() { - s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt) + s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.SndNxt) } // updateRTO updates the retransmit timeout when a new roud-trip time is // available. This is done in accordance with section 2 of RFC 6298. func (s *sender) updateRTO(rtt time.Duration) { s.rtt.Lock() - if !s.rtt.srttInited { - s.rtt.rttvar = rtt / 2 - s.rtt.srtt = rtt - s.rtt.srttInited = true + if !s.rtt.TCPRTTState.SRTTInited { + s.rtt.TCPRTTState.RTTVar = rtt / 2 + s.rtt.TCPRTTState.SRTT = rtt + s.rtt.TCPRTTState.SRTTInited = true } else { - diff := s.rtt.srtt - rtt + diff := s.rtt.TCPRTTState.SRTT - rtt if diff < 0 { diff = -diff } - // Use RFC6298 standard algorithm to update rttvar and srtt when + // Use RFC6298 standard algorithm to update TCPRTTState.RTTVar and TCPRTTState.SRTT when // no timestamps are available. - if !s.ep.sendTSOk { - s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4 - s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8 + if !s.ep.SendTSOk { + s.rtt.TCPRTTState.RTTVar = (3*s.rtt.TCPRTTState.RTTVar + diff) / 4 + s.rtt.TCPRTTState.SRTT = (7*s.rtt.TCPRTTState.SRTT + rtt) / 8 } else { // When we are taking RTT measurements of every ACK then // we need to use a modified method as specified in // https://tools.ietf.org/html/rfc7323#appendix-G - if s.outstanding == 0 { + if s.Outstanding == 0 { s.rtt.Unlock() return } @@ -449,7 +361,7 @@ func (s *sender) updateRTO(rtt time.Duration) { // terms of packets and not bytes. This is similar to // how linux also does cwnd and inflight. In practice // this approximation works as expected. - expectedSamples := math.Ceil(float64(s.outstanding) / 2) + expectedSamples := math.Ceil(float64(s.Outstanding) / 2) // alpha & beta values are the original values as recommended in // https://tools.ietf.org/html/rfc6298#section-2.3. 
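The updateRTO hunks move the estimator onto TCPRTTState but keep the RFC 6298 smoothing: RTTVAR and SRTT are blended with fixed gains and RTO = SRTT + 4*RTTVAR, floored at the configured minimum. A standalone sketch of the classic non-timestamp path; the per-ACK timestamp variant that scales the gains by the expected samples per window is omitted:

package main

import (
	"fmt"
	"time"
)

type rttState struct {
	srtt, rttvar, rto time.Duration
	inited            bool
}

// sample folds one RTT measurement into the estimator per RFC 6298 section 2
// and recomputes RTO, clamped below by minRTO.
func (r *rttState) sample(rtt, minRTO time.Duration) {
	if !r.inited {
		r.rttvar = rtt / 2
		r.srtt = rtt
		r.inited = true
	} else {
		diff := r.srtt - rtt
		if diff < 0 {
			diff = -diff
		}
		r.rttvar = (3*r.rttvar + diff) / 4 // beta = 1/4
		r.srtt = (7*r.srtt + rtt) / 8      // alpha = 1/8
	}
	r.rto = r.srtt + 4*r.rttvar
	if r.rto < minRTO {
		r.rto = minRTO
	}
}

func main() {
	var r rttState
	r.sample(100*time.Millisecond, 200*time.Millisecond)
	r.sample(120*time.Millisecond, 200*time.Millisecond)
	fmt.Println(r.srtt, r.rttvar, r.rto) // 102.5ms 42.5ms 272.5ms
}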
@@ -458,17 +370,17 @@ func (s *sender) updateRTO(rtt time.Duration) { alphaPrime := alpha / expectedSamples betaPrime := beta / expectedSamples - rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds() - srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds() - s.rtt.rttvar = time.Duration(rttVar * float64(time.Second)) - s.rtt.srtt = time.Duration(srtt * float64(time.Second)) + rttVar := (1-betaPrime)*s.rtt.TCPRTTState.RTTVar.Seconds() + betaPrime*diff.Seconds() + srtt := (1-alphaPrime)*s.rtt.TCPRTTState.SRTT.Seconds() + alphaPrime*rtt.Seconds() + s.rtt.TCPRTTState.RTTVar = time.Duration(rttVar * float64(time.Second)) + s.rtt.TCPRTTState.SRTT = time.Duration(srtt * float64(time.Second)) } } - s.rto = s.rtt.srtt + 4*s.rtt.rttvar + s.RTO = s.rtt.TCPRTTState.SRTT + 4*s.rtt.TCPRTTState.RTTVar s.rtt.Unlock() - if s.rto < s.minRTO { - s.rto = s.minRTO + if s.RTO < s.minRTO { + s.RTO = s.minRTO } } @@ -476,20 +388,20 @@ func (s *sender) updateRTO(rtt time.Duration) { func (s *sender) resendSegment() { // Don't use any segments we already sent to measure RTT as they may // have been affected by packets being lost. - s.rttMeasureSeqNum = s.sndNxt + s.RTTMeasureSeqNum = s.SndNxt // Resend the segment. if seg := s.writeList.Front(); seg != nil { - if seg.data.Size() > s.maxPayloadSize { - s.splitSeg(seg, s.maxPayloadSize) + if seg.data.Size() > s.MaxPayloadSize { + s.splitSeg(seg, s.MaxPayloadSize) } // See: RFC 6675 section 5 Step 4.3 // // To prevent retransmission, set both the HighRXT and RescueRXT // to the highest sequence number in the retransmitted segment. - s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 - s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.FastRecovery.HighRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.FastRecovery.RescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() s.ep.stats.SendErrors.FastRetransmit.Increment() @@ -554,15 +466,15 @@ func (s *sender) retransmitTimerExpired() bool { // Set new timeout. The timer will be restarted by the call to sendData // below. - s.rto *= 2 + s.RTO *= 2 // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5 - if s.rto > s.maxRTO { - s.rto = s.maxRTO + if s.RTO > s.maxRTO { + s.RTO = s.maxRTO } // Cap RTO to remaining time. - if s.rto > remaining { - s.rto = remaining + if s.RTO > remaining { + s.RTO = remaining } // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. @@ -571,9 +483,9 @@ func (s *sender) retransmitTimerExpired() bool { // After a retransmit timeout, record the highest sequence number // transmitted in the variable recover, and exit the fast recovery // procedure if applicable. - s.fr.last = s.sndNxt - 1 + s.FastRecovery.Last = s.SndNxt - 1 - if s.fr.active { + if s.FastRecovery.Active { // We were attempting fast recovery but were not successful. // Leave the state. We don't need to update ssthresh because it // has already been updated when entered fast-recovery. @@ -589,7 +501,7 @@ func (s *sender) retransmitTimerExpired() bool { // // We'll keep on transmitting (or retransmitting) as we get acks for // the data we transmit. - s.outstanding = 0 + s.Outstanding = 0 // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1 // @@ -663,7 +575,7 @@ func (s *sender) splitSeg(seg *segment, size int) { // window space. 
// ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() - if seg.data.Size() > s.maxPayloadSize { + if seg.data.Size() > s.MaxPayloadSize { seg.flags ^= header.TCPFlagPsh } @@ -689,7 +601,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // transmitted (i.e. either it has no assigned sequence number // or if it does have one, it's >= the next sequence number // to be sent [i.e. >= s.sndNxt]). - if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) { + if !s.isAssignedSequenceNumber(seg) || s.SndNxt.LessThanEq(seg.sequenceNumber) { hint = nil break } @@ -710,7 +622,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // (1.a) S2 is greater than HighRxt // (1.b) S2 is less than highest octect covered by // any received SACK. - if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { + if s.FastRecovery.HighRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { // NextSeg(): // (1.c) IsLost(S2) returns true. if s.ep.scoreboard.IsLost(segSeq) { @@ -743,7 +655,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // unSACKed sequence number SHOULD be returned, and // RescueRxt set to RecoveryPoint. HighRxt MUST NOT // be updated. - if s.fr.rescueRxt.LessThan(s.sndUna - 1) { + if s.FastRecovery.RescueRxt.LessThan(s.SndUna - 1) { if s4 != nil { if s4.sequenceNumber.LessThan(segSeq) { s4 = seg @@ -763,7 +675,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // previously unsent data starting with sequence number // HighData+1 MUST be returned." for seg := s.writeNext; seg != nil; seg = seg.Next() { - if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { + if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.SndNxt) { continue } // We do not split the segment here to <= smss as it has @@ -788,7 +700,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(s.sndNxt.Size(end)) + available := int(s.SndNxt.Size(end)) if available > limit { available = limit } @@ -816,7 +728,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } if !nextTooBig && seg.data.Size() < available { // Segment is not full. - if s.outstanding > 0 && s.ep.ops.GetDelayOption() { + if s.Outstanding > 0 && s.ep.ops.GetDelayOption() { // Nagle's algorithm. From Wikipedia: // Nagle's algorithm works by // combining a number of small @@ -835,7 +747,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // send space and MSS. // TODO(gvisor.dev/issue/2833): Drain the held segments after a // timeout. - if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() { + if seg.data.Size() < s.MaxPayloadSize && s.ep.ops.GetCorkOption() { return false } } @@ -843,7 +755,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // Assign flags. We don't do it above so that we can merge // additional data if Nagle holds the segment. 
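maybeSendSegment above holds back a partially filled, unsent segment when Nagle's algorithm or TCP_CORK applies, assigning its sequence number and flags only once it is released. A condensed sketch of that hold decision; the boolean option flags are stand-ins for the endpoint's GetDelayOption/GetCorkOption:

package main

import "fmt"

// holdSmallSegment reports whether an unsent, not-yet-full segment should be
// held back: Nagle delays it while older data is still unacknowledged, and
// TCP_CORK delays anything smaller than a full MSS.
func holdSmallSegment(segSize, mss, outstanding int, nagle, cork bool) bool {
	if segSize >= mss {
		return false // full-sized segments always go out
	}
	if nagle && outstanding > 0 {
		return true // wait for in-flight data to be ACKed, then coalesce
	}
	if cork {
		return true // corked: only ship full segments
	}
	return false
}

func main() {
	fmt.Println(holdSmallSegment(100, 1460, 2, true, false)) // true: Nagle holds it
	fmt.Println(holdSmallSegment(100, 1460, 0, true, false)) // false: nothing in flight
	fmt.Println(holdSmallSegment(100, 1460, 0, false, true)) // true: corked
}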
- seg.sequenceNumber = s.sndNxt + seg.sequenceNumber = s.SndNxt seg.flags = header.TCPFlagAck | header.TCPFlagPsh } @@ -893,12 +805,12 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // the segment right here if there are no pending segments. If // there are pending segments, segment transmits are deferred to // the retransmit timer handler. - if s.sndUna != s.sndNxt { + if s.SndUna != s.SndNxt { switch { case available >= seg.data.Size(): // OK to send, the whole segments fits in the // receiver's advertised window. - case available >= s.maxPayloadSize: + case available >= s.MaxPayloadSize: // OK to send, at least 1 MSS sized segment fits // in the receiver's advertised window. default: @@ -918,8 +830,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // If GSO is not in use then cap available to // maxPayloadSize. When GSO is in use the gVisor GSO logic or // the host GSO logic will cap the segment to the correct size. - if s.ep.gso == nil && available > s.maxPayloadSize { - available = s.maxPayloadSize + if s.ep.gso == nil && available > s.MaxPayloadSize { + available = s.MaxPayloadSize } if seg.data.Size() > available { @@ -933,8 +845,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // Update sndNxt if we actually sent new data (as opposed to // retransmitting some previously sent data). - if s.sndNxt.LessThan(segEnd) { - s.sndNxt = segEnd + if s.SndNxt.LessThan(segEnd) { + s.SndNxt = segEnd } return true @@ -945,9 +857,9 @@ func (s *sender) sendZeroWindowProbe() { s.unackZeroWindowProbes++ // Send a zero window probe with sequence number pointing to // the last acknowledged byte. - s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win) + s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.SndUna-1, ack, win) // Rearm the timer to continue probing. - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } func (s *sender) enableZeroWindowProbing() { @@ -958,7 +870,7 @@ func (s *sender) enableZeroWindowProbing() { if s.firstRetransmittedSegXmitTime.IsZero() { s.firstRetransmittedSegXmitTime = time.Now() } - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } func (s *sender) disableZeroWindowProbing() { @@ -978,12 +890,12 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { // If the sender has advertized zero receive window and we have // data to be sent out, start zero window probing to query the // the remote for it's receive window size. - if s.writeNext != nil && s.sndWnd == 0 { + if s.writeNext != nil && s.SndWnd == 0 { s.enableZeroWindowProbing() } // If we have no more pending data, start the keepalive timer. - if s.sndUna == s.sndNxt { + if s.SndUna == s.SndNxt { s.ep.resetKeepaliveTimer(false) } else { // Enable timers if we have pending data. @@ -992,10 +904,10 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { s.schedulePTO() } else if !s.resendTimer.enabled() { s.probeTimer.disable() - if s.outstanding > 0 { + if s.Outstanding > 0 { // Enable the resend timer if it's not enabled yet and there is // outstanding data. - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } } } @@ -1004,29 +916,29 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. 
func (s *sender) sendData() { - limit := s.maxPayloadSize + limit := s.MaxPayloadSize if s.gso { limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) } - end := s.sndUna.Add(s.sndWnd) + end := s.SndUna.Add(s.SndWnd) // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { - if s.sndCwnd > InitialCwnd { - s.sndCwnd = InitialCwnd + if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO { + if s.SndCwnd > InitialCwnd { + s.SndCwnd = InitialCwnd } } var dataSent bool - for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { - cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + for seg := s.writeNext; seg != nil && s.Outstanding < s.SndCwnd; seg = seg.Next() { + cwndLimit := (s.SndCwnd - s.Outstanding) * s.MaxPayloadSize if cwndLimit < limit { limit = cwndLimit } - if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + if s.isAssignedSequenceNumber(seg) && s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Move writeNext along so that we don't try and scan data that // has already been SACKED. s.writeNext = seg.Next() @@ -1036,7 +948,7 @@ func (s *sender) sendData() { break } dataSent = true - s.outstanding += s.pCount(seg, s.maxPayloadSize) + s.Outstanding += s.pCount(seg, s.MaxPayloadSize) s.writeNext = seg.Next() } @@ -1044,21 +956,21 @@ func (s *sender) sendData() { } func (s *sender) enterRecovery() { - s.fr.active = true + s.FastRecovery.Active = true // Save state to reflect we're now in fast recovery. // // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. - s.sndCwnd = s.sndSsthresh + 3 - s.sackedOut = 0 - s.dupAckCount = 0 - s.fr.first = s.sndUna - s.fr.last = s.sndNxt - 1 - s.fr.maxCwnd = s.sndCwnd + s.outstanding - s.fr.highRxt = s.sndUna - s.fr.rescueRxt = s.sndUna - if s.ep.sackPermitted { + s.SndCwnd = s.Ssthresh + 3 + s.SackedOut = 0 + s.DupAckCount = 0 + s.FastRecovery.First = s.SndUna + s.FastRecovery.Last = s.SndNxt - 1 + s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding + s.FastRecovery.HighRxt = s.SndUna + s.FastRecovery.RescueRxt = s.SndUna + if s.ep.SACKPermitted { s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() // Set TLPRxtOut to false according to @@ -1075,12 +987,12 @@ func (s *sender) enterRecovery() { } func (s *sender) leaveRecovery() { - s.fr.active = false - s.fr.maxCwnd = 0 - s.dupAckCount = 0 + s.FastRecovery.Active = false + s.FastRecovery.MaxCwnd = 0 + s.DupAckCount = 0 // Deflate cwnd. It had been artificially inflated when new dups arrived. - s.sndCwnd = s.sndSsthresh + s.SndCwnd = s.Ssthresh s.cc.PostRecovery() } @@ -1099,7 +1011,7 @@ func (s *sender) isAssignedSequenceNumber(seg *segment) bool { func (s *sender) SetPipe() { // If SACK isn't permitted or it is permitted but recovery is not active // then ignore pipe calculations. 
- if !s.ep.sackPermitted || !s.fr.active { + if !s.ep.SACKPermitted || !s.FastRecovery.Active { return } pipe := 0 @@ -1119,7 +1031,7 @@ func (s *sender) SetPipe() { // After initializing pipe to zero, the following steps are // taken for each octet 'S1' in the sequence space between // HighACK and HighData that has not been SACKed: - if !s1.sequenceNumber.LessThan(s.sndNxt) { + if !s1.sequenceNumber.LessThan(s.SndNxt) { break } if s.ep.scoreboard.IsSACKED(sb) { @@ -1138,20 +1050,20 @@ func (s *sender) SetPipe() { } // SetPipe(): // (b) If S1 <= HighRxt, Pipe is incremented by 1. - if s1.sequenceNumber.LessThanEq(s.fr.highRxt) { + if s1.sequenceNumber.LessThanEq(s.FastRecovery.HighRxt) { pipe++ } } } - s.outstanding = pipe + s.Outstanding = pipe } // shouldEnterRecovery returns true if the sender should enter fast recovery // based on dupAck count and sack scoreboard. // See RFC 6675 section 5. func (s *sender) shouldEnterRecovery() bool { - return s.dupAckCount >= nDupAckThreshold || - (s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.sndUna)) + return s.DupAckCount >= nDupAckThreshold || + (s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.SndUna)) } // detectLoss is called when an ack is received and returns whether a loss is @@ -1163,24 +1075,24 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // If RACK is enabled and there is no reordering we should honor the // three duplicate ACK rule to enter recovery. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-4 - if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { - if s.rc.reorderSeen { + if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.rc.Reord { return false } } if !s.isDupAck(seg) { - s.dupAckCount = 0 + s.DupAckCount = 0 return false } - s.dupAckCount++ + s.DupAckCount++ // Do not enter fast recovery until we reach nDupAckThreshold or the // first unacknowledged byte is considered lost as per SACK scoreboard. if !s.shouldEnterRecovery() { // RFC 6675 Step 3. - s.fr.highRxt = s.sndUna - 1 + s.FastRecovery.HighRxt = s.SndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() s.state = tcpip.Disorder @@ -1196,8 +1108,8 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // Note that we only enter recovery when at least one more byte of data // beyond s.fr.last (the highest byte that was outstanding when fast // retransmit was last entered) is acked. - if !s.fr.last.LessThan(seg.ackNumber - 1) { - s.dupAckCount = 0 + if !s.FastRecovery.Last.LessThan(seg.ackNumber - 1) { + s.DupAckCount = 0 return false } s.cc.HandleLossDetected() @@ -1212,22 +1124,22 @@ func (s *sender) isDupAck(seg *segment) bool { // can leverage the SACK information to determine when an incoming ACK is a // "duplicate" (e.g., if the ACK contains previously unknown SACK // information). - if s.ep.sackPermitted && !seg.hasNewSACKInfo { + if s.ep.SACKPermitted && !seg.hasNewSACKInfo { return false } // (a) The receiver of the ACK has outstanding data. - return s.sndUna != s.sndNxt && + return s.SndUna != s.SndNxt && // (b) The incoming acknowledgment carries no data. seg.logicalLen() == 0 && // (c) The SYN and FIN bits are both off. !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && // (d) the ACK number is equal to the greatest acknowledgment received on // the given connection (TCP.UNA from RFC793). 
- seg.ackNumber == s.sndUna && + seg.ackNumber == s.SndUna && // (e) the advertised window in the incoming acknowledgment equals the // advertised window in the last incoming acknowledgment. - s.sndWnd == seg.window + s.SndWnd == seg.window } // Iterate the writeList and update RACK for each segment which is newly acked @@ -1267,7 +1179,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) seg.acked = true - s.sackedOut += s.pCount(seg, s.maxPayloadSize) + s.SackedOut += s.pCount(seg, s.MaxPayloadSize) } seg = seg.Next() } @@ -1322,18 +1234,18 @@ func checkDSACK(rcvdSeg *segment) bool { // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. - if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { - s.updateRTO(time.Now().Sub(s.rttMeasureTime)) - s.rttMeasureSeqNum = s.sndNxt + if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { + s.updateRTO(time.Now().Sub(s.RTTMeasureTime)) + s.RTTMeasureSeqNum = s.SndNxt } // Update Timestamp if required. See RFC7323, section-4.3. - if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS { - s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber) + if s.ep.SendTSOk && rcvdSeg.parsedOptions.TS { + s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.MaxSentAck, rcvdSeg.sequenceNumber) } // Insert SACKBlock information into our scoreboard. - if s.ep.sackPermitted { + if s.ep.SACKPermitted { for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds // true: @@ -1347,7 +1259,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // NOTE: This check specifically excludes DSACK blocks // which have start/end before sndUna and are used to // indicate spurious retransmissions. - if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + if rcvdSeg.ackNumber.LessThan(sb.Start) && s.SndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.SndNxt) && !s.ep.scoreboard.IsSACKED(sb) { s.ep.scoreboard.Insert(sb) rcvdSeg.hasNewSACKInfo = true } @@ -1375,10 +1287,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ack := rcvdSeg.ackNumber fastRetransmit := false // Do not leave fast recovery, if the ACK is out of range. - if s.fr.active { + if s.FastRecovery.Active { // Leave fast recovery if it acknowledges all the data covered by // this fast recovery session. - if (ack-1).InRange(s.sndUna, s.sndNxt) && s.fr.last.LessThan(ack) { + if (ack-1).InRange(s.SndUna, s.SndNxt) && s.FastRecovery.Last.LessThan(ack) { s.leaveRecovery() } } else { @@ -1392,28 +1304,28 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Stash away the current window size. - s.sndWnd = rcvdSeg.window + s.SndWnd = rcvdSeg.window // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously // unacknowledged segment. 
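isDupAck above applies the RFC 5681 duplicate-ACK tests, with the RFC 6675 refinement that on a SACK-permitted connection the ACK must also carry previously unknown SACK information. A condensed sketch with plain ints standing in for sequence numbers and window sizes:

package main

import "fmt"

type ackSeg struct {
	ackNum, payloadLen int
	syn, fin           bool
	window             int
	newSACKInfo        bool
}

// isDupAck reports whether seg counts as a duplicate ACK: the sender has
// outstanding data, the ACK carries no payload, no SYN or FIN, repeats the
// highest cumulative ACK, leaves the window unchanged, and (when SACK is in
// use) brings new SACK information.
func isDupAck(seg ackSeg, sndUna, sndNxt, sndWnd int, sackPermitted bool) bool {
	if sackPermitted && !seg.newSACKInfo {
		return false
	}
	return sndUna != sndNxt &&
		seg.payloadLen == 0 &&
		!seg.syn && !seg.fin &&
		seg.ackNum == sndUna &&
		seg.window == sndWnd
}

func main() {
	seg := ackSeg{ackNum: 1000, window: 65535, newSACKInfo: true}
	fmt.Println(isDupAck(seg, 1000, 4000, 65535, true)) // true
}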
if s.zeroWindowProbing && rcvdSeg.window > 0 && - (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { + (ack == s.SndUna || (ack-1).InRange(s.SndUna, s.SndNxt)) { s.disableZeroWindowProbing() } // On receiving the ACK for the zero window probe, account for it and // skip trying to send any segment as we are still probing for // receive window to become non-zero. - if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna { + if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.SndUna { s.unackZeroWindowProbes-- return } // Ignore ack if it doesn't acknowledge any new data. - if (ack - 1).InRange(s.sndUna, s.sndNxt) { - s.dupAckCount = 0 + if (ack - 1).InRange(s.SndUna, s.SndNxt) { + s.DupAckCount = 0 // See : https://tools.ietf.org/html/rfc1323#section-3.3. // Specifically we should only update the RTO using TSEcr if the @@ -1423,7 +1335,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // averaged RTT measurement only if the segment acknowledges // some new data, i.e., only if it advances the left edge of // the send window. - if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { + if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { // TSVal/Ecr values sent by Netstack are at a millisecond // granularity. elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond @@ -1438,12 +1350,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // When an ack is received we must rearm the timer. // RFC 6298 5.3 s.probeTimer.disable() - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } // Remove all acknowledged data from the write list. - acked := s.sndUna.Size(ack) - s.sndUna = ack + acked := s.SndUna.Size(ack) + s.SndUna = ack // The remote ACK-ing at least 1 byte is an indication that we have a // full-duplex connection to the remote as the only way we will receive an @@ -1457,7 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } ackLeft := acked - originalOutstanding := s.outstanding + originalOutstanding := s.Outstanding for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1466,10 +1378,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg, s.maxPayloadSize) + prevCount := s.pCount(seg, s.MaxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) + s.Outstanding -= prevCount - s.pCount(seg, s.MaxPayloadSize) break } @@ -1478,7 +1390,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Update the RACK fields if SACK is enabled. - if s.ep.sackPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.ep.SACKPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1488,10 +1400,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). 
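The timestamp path above turns TSEcr into an RTT sample at millisecond granularity, hands it to updateRTO, and then rearms the resend timer with s.RTO (RFC 6298 rule 5.3). For context, the standard RFC 6298 estimator that an update of this shape implements looks roughly like the sketch below; it is the textbook recurrence with alpha = 1/8, beta = 1/4 and K = 4, not gvisor's exact updateRTO, and the 200ms floor is a value picked only for the sketch:

package main

import (
    "fmt"
    "time"
)

// rttState holds the RFC 6298 smoothed RTT estimator.
type rttState struct {
    srtt, rttvar, rto time.Duration
    hasSample         bool
}

// update folds one RTT measurement into SRTT/RTTVAR and recomputes RTO.
func (s *rttState) update(rtt time.Duration) {
    if !s.hasSample {
        // First measurement: SRTT = R, RTTVAR = R/2.
        s.srtt = rtt
        s.rttvar = rtt / 2
        s.hasSample = true
    } else {
        // RTTVAR = 3/4*RTTVAR + 1/4*|SRTT-R|, SRTT = 7/8*SRTT + 1/8*R.
        diff := s.srtt - rtt
        if diff < 0 {
            diff = -diff
        }
        s.rttvar = (3*s.rttvar + diff) / 4
        s.srtt = (7*s.srtt + rtt) / 8
    }
    s.rto = s.srtt + 4*s.rttvar
    if s.rto < 200*time.Millisecond {
        s.rto = 200 * time.Millisecond
    }
}

func main() {
    var s rttState
    for _, ms := range []int{100, 120, 80} {
        s.update(time.Duration(ms) * time.Millisecond)
    }
    fmt.Println(s.srtt, s.rttvar, s.rto)
}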
- if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg, s.maxPayloadSize) + if !s.ep.SACKPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + s.Outstanding -= s.pCount(seg, s.MaxPayloadSize) } else { - s.sackedOut -= s.pCount(seg, s.maxPayloadSize) + s.SackedOut -= s.pCount(seg, s.MaxPayloadSize) } seg.decRef() ackLeft -= datalen @@ -1501,13 +1413,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.ep.updateSndBufferUsage(int(acked)) // Clear SACK information for all acked data. - s.ep.scoreboard.Delete(s.sndUna) + s.ep.scoreboard.Delete(s.SndUna) // If we are not in fast recovery then update the congestion // window based on the number of acknowledged packets. - if !s.fr.active { - s.cc.Update(originalOutstanding - s.outstanding) - if s.fr.last.LessThan(s.sndUna) { + if !s.FastRecovery.Active { + s.cc.Update(originalOutstanding - s.Outstanding) + if s.FastRecovery.Last.LessThan(s.SndUna) { s.state = tcpip.Open // Update RACK when we are exiting fast or RTO // recovery as described in the RFC @@ -1522,16 +1434,16 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // It is possible for s.outstanding to drop below zero if we get // a retransmit timeout, reset outstanding to zero but later // get an ack that cover previously sent data. - if s.outstanding < 0 { - s.outstanding = 0 + if s.Outstanding < 0 { + s.Outstanding = 0 } s.SetPipe() // If all outstanding data was acknowledged the disable the timer. // RFC 6298 Rule 5.3 - if s.sndUna == s.sndNxt { - s.outstanding = 0 + if s.SndUna == s.SndNxt { + s.Outstanding = 0 // Reset firstRetransmittedSegXmitTime to the zero value. s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() @@ -1539,7 +1451,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } - if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { // Update RACK reorder window. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: @@ -1549,7 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the // reorder window. - if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.fr.active { + if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.FastRecovery.Active { // If any segment is marked as lost by // RACK, enter recovery and retransmit // the lost segments. @@ -1558,19 +1470,19 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { fastRetransmit = true } - if s.fr.active { + if s.FastRecovery.Active { s.rc.DoRecovery(nil, fastRetransmit) } } // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. - if s.fr.active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { + if s.FastRecovery.Active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { s.lr.DoRecovery(rcvdSeg, fastRetransmit) // When SACK is enabled data sending is governed by steps in // RFC 6675 Section 5 recovery steps A-C. // See: https://tools.ietf.org/html/rfc6675#section-5. 
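The RACK hookup above recomputes the reorder window on every ACK and then asks rc.detectLoss whether any outstanding segment has been out longer than the reorder window allows, entering recovery if so. A much-simplified standalone version of that time-based test follows; per the RACK draft the deadline is the segment's send time plus the latest RTT plus the reorder window, and the full algorithm additionally requires that some later-sent segment has already been delivered, which this sketch deliberately omits (all names here are invented for the example):

package main

import (
    "fmt"
    "time"
)

// sentSegment is a stand-in for an unacknowledged segment on the write list.
type sentSegment struct {
    seq      uint32
    xmitTime time.Time
    acked    bool
}

// rackLost returns the sequence numbers of unacked segments whose age at
// time now exceeds rtt+reoWnd, the RACK loss deadline. It skips the
// "a later-sent segment was delivered" precondition of the full algorithm.
func rackLost(segs []sentSegment, rtt, reoWnd time.Duration, now time.Time) []uint32 {
    var lost []uint32
    for _, s := range segs {
        if !s.acked && now.Sub(s.xmitTime) > rtt+reoWnd {
            lost = append(lost, s.seq)
        }
    }
    return lost
}

func main() {
    now := time.Now()
    segs := []sentSegment{
        {seq: 100, xmitTime: now.Add(-300 * time.Millisecond)},
        {seq: 200, xmitTime: now.Add(-50 * time.Millisecond)},
    }
    fmt.Println(rackLost(segs, 100*time.Millisecond, 40*time.Millisecond, now)) // [100]
}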
- if s.ep.sackPermitted { + if s.ep.SACKPermitted { return } } @@ -1587,7 +1499,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() - if s.sndCwnd < s.sndSsthresh { + if s.SndCwnd < s.Ssthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } } @@ -1601,11 +1513,11 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // then use the conservative timer described in RFC6675 Section 6.0, // otherwise follow the standard time described in RFC6298 Section 5.1. if err != nil && seg.data.Size() != 0 { - if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted { - s.resendTimer.enable(s.rto) + if s.FastRecovery.Active && seg.xmitCount > 1 && s.ep.SACKPermitted { + s.resendTimer.enable(s.RTO) } else { if !s.resendTimer.enabled() { - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } } } @@ -1616,15 +1528,15 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { - s.lastSendTime = time.Now() - if seq == s.rttMeasureSeqNum { - s.rttMeasureTime = s.lastSendTime + s.LastSendTime = time.Now() + if seq == s.RTTMeasureSeqNum { + s.RTTMeasureTime = s.LastSendTime } rcvNxt, rcvWnd := s.ep.rcv.getSendParams() // Remember the max sent ack. - s.maxSentAck = rcvNxt + s.MaxSentAck = rcvNxt return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index ba41cff6d..2f805d8ce 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -24,26 +24,6 @@ type unixTime struct { nano int64 } -// saveLastSendTime is invoked by stateify. -func (s *sender) saveLastSendTime() unixTime { - return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (s *sender) loadLastSendTime(unix unixTime) { - s.lastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRttMeasureTime is invoked by stateify. -func (s *sender) saveRttMeasureTime() unixTime { - return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()} -} - -// loadRttMeasureTime is invoked by stateify. -func (s *sender) loadRttMeasureTime(unix unixTime) { - s.rttMeasureTime = time.Unix(unix.second, unix.nano) -} - // afterLoad is invoked by stateify. func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 5cdd5b588..c58361bc1 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -33,6 +33,7 @@ const ( tsOptionSize = 12 maxTCPOptionSize = 40 mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload + latency = 5 * time.Millisecond ) func setStackRACKPermitted(t *testing.T, c *context.Context) { @@ -182,6 +183,9 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en for i := 0; i < numPackets; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload + // This delay is added to increase RTT as low RTT can cause TLP + // before sending ACK. 
+ time.Sleep(latency) } return data @@ -479,7 +483,7 @@ func TestRACKOnePacketTailLoss(t *testing.T) { }{ // #3 was retransmitted as TLP. {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, // RTO should not have fired. {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, @@ -852,8 +856,8 @@ func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan return } - if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { - probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) + if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT { + probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT) return } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 81f800cad..20c9761f2 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -160,12 +160,9 @@ func TestSackPermittedAccept(t *testing.T) { defer c.Cleanup() if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -235,12 +232,9 @@ func TestSackDisabledAccept(t *testing.T) { defer c.Cleanup() if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9c23469f2..9f29a48fb 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" + tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" "gvisor.dev/gvisor/pkg/test/testutil" @@ -929,10 +930,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { } // Get expected window size. 
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) - } + rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} @@ -955,11 +953,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { // when completing the handshake for a new TCP connection from a TCP // listening socket. It should be present in the sent TCP SYN-ACK segment. func TestUserSuppliedMSSOnListenAccept(t *testing.T) { - const ( - nonSynCookieAccepts = 2 - totalAccepts = 4 - mtu = 5000 - ) + const mtu = 5000 ips := []struct { name string @@ -1033,12 +1027,6 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { ip.createEP(c) - // Set the SynRcvd threshold to force a syn cookie based accept to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) } @@ -1048,13 +1036,17 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { t.Fatalf("Bind(%+v): %s:", bindAddr, err) } - if err := c.EP.Listen(totalAccepts); err != nil { - t.Fatalf("Listen(%d): %s:", totalAccepts, err) + backlog := 5 + // Keep the number of client requests twice to the backlog + // such that half of the connections do not use syncookies + // and the other half does. + clientConnects := backlog * 2 + + if err := c.EP.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s:", backlog, err) } - // The first nonSynCookieAccepts packets sent will trigger a gorooutine - // based accept. The rest will trigger a cookie based accept. - for i := 0; i < totalAccepts; i++ { + for i := 0; i < clientConnects; i++ { // Send a SYN requests. 
iss := seqnum.Value(i) srcPort := context.TestPort + uint16(i) @@ -1297,6 +1289,98 @@ func TestListenShutdown(t *testing.T) { )) } +var _ waiter.EntryCallback = (callback)(nil) + +type callback func(*waiter.Entry, waiter.EventMask) + +func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) { + cb(entry, mask) +} + +func TestListenerReadinessOnEvent(t *testing.T) { + s := stack.New(stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + { + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } + const id = 1 + if err := s.CreateNIC(id, ep); err != nil { + t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) + } + if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { + t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: header.IPv4EmptySubnet, NIC: id}, + }) + } + + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { + t.Fatalf("Bind(%s): %s", context.StackAddr, err) + } + const backlog = 1 + if err := ep.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s", backlog, err) + } + + address, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress(): %s", err) + } + + conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer conn.Close() + + events := make(chan waiter.EventMask) + // Scope `entry` to allow a binding of the same name below. + { + entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) { + events <- ep.Readiness(mask) + })} + wq.EventRegister(&entry, waiter.EventIn) + defer wq.EventUnregister(&entry) + } + + entry, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&entry, waiter.EventOut) + defer wq.EventUnregister(&entry) + + switch err := conn.Connect(address).(type) { + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect(%#v): %v", address, err) + } + + // Read at least one event. + got := <-events + for { + select { + case event := <-events: + got |= event + continue + case <-ch: + if want := waiter.ReadableEvents; got != want { + t.Errorf("observed events = %b, want %b", got, want) + } + } + break + } +} + // TestListenCloseWhileConnect tests for the listening endpoint to // drain the accept-queue when closed. This should reset all of the // pending connections that are waiting to be accepted. @@ -1993,9 +2077,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Bump up the receive buffer size such that, when the receive window grows, // the scaled window exceeds maxUint16. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true) // Keep the payload size < segment overhead and such that it is a multiple // of the window scaled value. 
This enables the test to perform equality @@ -2115,9 +2197,7 @@ func TestNoWindowShrinking(t *testing.T) { initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) // Now shrink the receive buffer to half its original size. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true) data := generateRandomPayload(t, rcvBufSize) // Send a payload of half the size of rcvBufSize. @@ -2373,9 +2453,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } + ep.SocketOptions().SetReceiveBufferSize(65535*3, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -2447,9 +2525,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } + ep.SocketOptions().SetReceiveBufferSize(65535*3, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -3042,9 +3118,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } + ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -3087,11 +3161,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, mtu) defer c.Cleanup() - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(0) + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -3185,9 +3257,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 3 - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } + c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) // Start connection attempt. 
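Several hunks above swap SetSockOptInt(tcpip.ReceiveBufferSizeOption, ...) for SocketOptions().SetReceiveBufferSize(..., true) and then verify the window-scale option implied by the buffer. As background, the usual RFC 7323 calculation behind such checks picks the smallest shift that lets the buffer fit in the 16-bit window field, capped at 14. The sketch below is that generic calculation, not gvisor's tcp.FindWndScale, whose exact results also depend on how the option layer adjusts the requested buffer size:

package main

import "fmt"

// findWndScale returns the smallest RFC 7323 window-scale shift such that
// rcvBuf right-shifted by it fits in the 16-bit TCP window field, capped at
// the protocol maximum of 14.
func findWndScale(rcvBuf int) int {
    const maxShift = 14
    ws := 0
    for rcvBuf > 0xffff && ws < maxShift {
        rcvBuf >>= 1
        ws++
    }
    return ws
}

func main() {
    for _, sz := range []int{65535, 0x20000, 1 << 20} {
        fmt.Printf("buffer %d -> wscale %d\n", sz, findWndScale(sz))
    }
}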
we, ch := waiter.NewChannelEntry(nil) @@ -4411,11 +4481,7 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) - } - + s := ep.SocketOptions().GetReceiveBufferSize() if int(s) != v { t.Fatalf("got receive buffer size = %d, want = %d", s, v) } @@ -4521,10 +4587,7 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min/2. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) - } - + ep.SocketOptions().SetReceiveBufferSize(99, true) checkRecvBufferSize(t, ep, 200) ep.SocketOptions().SetSendBufferSize(149, true) @@ -4532,15 +4595,11 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - + ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true) // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) - // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) } @@ -4814,7 +4873,13 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("unknown address type: '%s'", candidateAddressType) } - start, end := s.PortRange() + const ( + start = 16000 + end = 16050 + ) + if err := s.SetPortRange(start, end); err != nil { + t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) + } for i := start; i <= end; i++ { if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { t.Fatalf("Bind(%d) failed: %s", i, err) @@ -5363,7 +5428,7 @@ func TestListenBacklogFull(t *testing.T) { } lastPortOffset := uint16(0) - for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { + for ; int(lastPortOffset) < listenBacklog+1; lastPortOffset++ { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } @@ -5445,8 +5510,8 @@ func TestListenBacklogFull(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a // non unicast IPv4 address are not accepted. func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpip.Address("\xe0\x00\x01\x02") - otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + multicastAddr := tcpiptestutil.MustParse4("224.0.1.2") + otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3") subnet := context.StackAddrWithPrefix.Subnet() subnetBroadcastAddr := subnet.Broadcast() @@ -5557,8 +5622,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. 
func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + multicastAddr := tcpiptestutil.MustParse6("ff0e::101") + otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102") tests := []struct { name string @@ -5671,15 +5736,13 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } // Test acceptance. - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { + if err := c.EP.Listen(0); err != nil { t.Fatalf("Listen failed: %s", err) } // Send two SYN's the first one should get a SYN-ACK, the // second one should not get any response and is dropped as - // the synRcvd count will be equal to backlog. + // the accept queue is full. irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5701,23 +5764,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - // - // NOTE: we did not complete the handshake for the previous one so the - // accept backlog should be empty and there should be one connection in - // synRcvd state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Now complete the previous connection and verify that there is a connection - // to accept. + // Now complete the previous connection. // Send ACK. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5728,11 +5775,24 @@ func TestListenSynRcvdQueueFull(t *testing.T) { RcvWnd: 30000, }) - // Try to accept the connections in the backlog. + // Verify if that is delivered to the accept queue. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.ReadableEvents) defer c.WQ.EventUnregister(&we) + <-ch + + // Now execute send one more SYN. The stack should not respond as the backlog + // is full at this point. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + 1, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: seqnum.Value(889), + RcvWnd: 30000, + }) + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) + // Try to accept the connections in the backlog. newEP, _, err := c.EP.Accept(nil) if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. @@ -5764,11 +5824,6 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.TCPSynRcvdCountThresholdOption(1) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - // Create TCP endpoint. var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -5781,9 +5836,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { t.Fatalf("Bind failed: %s", err) } - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { + // Test for SynCookies usage after filling up the backlog. 
+ if err := c.EP.Listen(0); err != nil { t.Fatalf("Listen failed: %s", err) } @@ -6066,7 +6120,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } - if err := c.EP.Listen(1); err != nil { + if err := c.EP.Listen(0); err != nil { t.Fatalf("Listen failed: %s", err) } @@ -7553,8 +7607,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS - c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - + c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 2949588ce..1deb1fe4d 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -139,9 +139,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -202,9 +202,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index e73f90bb0..7578d64ec 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -757,9 +757,7 @@ func (c *Context) Create(epRcvBuf int) { } if epRcvBuf != -1 { - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */) } } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 153e8c950..dd5c910ae 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -56,6 +56,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 956da0e0c..c9f2f3efc 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "fmt" "io" "sync/atomic" @@ -89,12 +88,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. 
- rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList udpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList udpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` @@ -144,6 +142,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // +stateify savable @@ -173,14 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - rcvBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -188,9 +190,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } return e @@ -622,26 +624,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { e.mu.Lock() e.sendTOS = uint8(v) e.mu.Unlock() - - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) - } - - if v < rs.Min { - v = rs.Min - } - if v > rs.Max { - v = rs.Max - } - - e.mu.Lock() - e.rcvBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -802,12 +784,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.mu.Lock() v := int(e.ttl) @@ -1255,20 +1231,29 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // verifyChecksum verifies the checksum unless RX checksum offload is enabled. -// On IPv4, UDP checksum is optional, and a zero value means the transmitter -// omitted the checksum generation (RFC768). -// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). 
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { - if !pkt.RXTransportChecksumValidated && - (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { - netHdr := pkt.Network() - xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) - for _, v := range pkt.Data().Views() { - xsum = header.Checksum(v, xsum) - } - return hdr.CalculateChecksum(xsum) == 0xffff + if pkt.RXTransportChecksumValidated { + return true + } + + // On IPv4, UDP checksum is optional, and a zero value means the transmitter + // omitted the checksum generation, as per RFC 768: + // + // An all zero transmitted checksum value means that the transmitter + // generated no checksum (for debugging or for higher level protocols that + // don't care). + // + // On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1: + // + // Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP + // checksum is not optional. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 { + return true } - return true + + netHdr := pkt.Network() + payloadChecksum := pkt.Data().AsRange().Checksum() + return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum) } // HandlePacket is called by the stack when new packets arrive to this transport @@ -1284,7 +1269,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } if !verifyChecksum(hdr, pkt) { - // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() return @@ -1302,7 +1286,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -1436,3 +1421,18 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 21a6aa460..4aba68b21 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,43 +37,25 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. 
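The rewritten verifyChecksum above makes the two RFC rules explicit: on IPv4 an all-zero checksum means the transmitter skipped checksumming (RFC 768), while on IPv6 the checksum is mandatory (RFC 2460 section 8.1); in every other case the ones-complement sum over the pseudo-header and datagram has to verify. The sketch below mirrors that decision with a toy checksum routine instead of gvisor's header helpers; the function names are invented for the example:

package main

import "fmt"

// onesComplementSum folds data into a 16-bit ones-complement sum, the
// primitive behind the Internet checksum used by UDP.
func onesComplementSum(sum uint32, data []byte) uint32 {
    for i := 0; i+1 < len(data); i += 2 {
        sum += uint32(data[i])<<8 | uint32(data[i+1])
    }
    if len(data)%2 == 1 {
        sum += uint32(data[len(data)-1]) << 8
    }
    for sum > 0xffff {
        sum = (sum >> 16) + (sum & 0xffff)
    }
    return sum
}

// udpChecksumOK follows the shape of verifyChecksum: trust hardware
// validation, accept a zero checksum only on IPv4, and otherwise require the
// sum over pseudo-header plus datagram (which includes the checksum field)
// to come out as 0xffff.
func udpChecksumOK(isIPv6, hwValidated bool, checksumField uint16, pseudoHdr, datagram []byte) bool {
    if hwValidated {
        return true
    }
    if !isIPv6 && checksumField == 0 {
        return true // RFC 768: the transmitter generated no checksum.
    }
    sum := onesComplementSum(0, pseudoHdr)
    sum = onesComplementSum(sum, datagram)
    return sum == 0xffff
}

func main() {
    fmt.Println(udpChecksumOK(false, false, 0, nil, nil)) // true: IPv4, zero checksum
    fmt.Println(udpChecksumOK(true, false, 0, nil, nil))  // false: IPv6 requires one
}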
- e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() + e.mu.Lock() defer e.mu.Unlock() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 77ca70a04..dc2e3f493 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -34,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -2364,7 +2365,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } ipv4Subnet := ipv4Addr.Subnet() ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4Gateway := testutil.MustParse4("192.168.1.1") ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ Address: "\xc0\xa8\x01\x3a", PrefixLen: 31, diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD index 7f983a0b3..366f068e3 100644 --- a/pkg/test/dockerutil/BUILD +++ b/pkg/test/dockerutil/BUILD @@ -36,8 +36,8 @@ go_test( tags = [ # Requires docker and runsc to be configured before test runs. # Also requires the test to be run as root. 
- "manual", "local", + "manual", ], visibility = ["//:sandbox"], ) diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 41fcf4978..06152a444 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -434,7 +434,14 @@ func (c *Container) Wait(ctx context.Context) error { select { case err := <-errChan: return err - case <-statusChan: + case res := <-statusChan: + if res.StatusCode != 0 { + var msg string + if res.Error != nil { + msg = res.Error.Message + } + return fmt.Errorf("container returned non-zero status: %d, msg: %q", res.StatusCode, msg) + } return nil } } diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD index 054269b59..3dba36f12 100644 --- a/pkg/usermem/BUILD +++ b/pkg/usermem/BUILD @@ -1,42 +1,22 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -go_template_instance( - name = "addr_range", - out = "addr_range.go", - package = "usermem", - prefix = "Addr", - template = "//pkg/segment:generic_range", - types = { - "T": "Addr", - }, -) - go_library( name = "usermem", srcs = [ - "access_type.go", - "addr.go", - "addr_range.go", - "addr_range_seq_unsafe.go", "bytes_io.go", "bytes_io_unsafe.go", "usermem.go", - "usermem_arm64.go", - "usermem_x86.go", ], visibility = ["//:sandbox"], deps = [ "//pkg/atomicbitops", - "//pkg/binary", "//pkg/context", "//pkg/gohacks", - "//pkg/log", + "//pkg/hostarch", "//pkg/safemem", "//pkg/syserror", - "@org_golang_x_sys//unix:go_default_library", ], ) @@ -44,12 +24,12 @@ go_test( name = "usermem_test", size = "small", srcs = [ - "addr_range_seq_test.go", "usermem_test.go", ], library = ":usermem", deps = [ "//pkg/context", + "//pkg/hostarch", "//pkg/safemem", "//pkg/syserror", ], diff --git a/pkg/usermem/bytes_io.go b/pkg/usermem/bytes_io.go index e177d30eb..3da3c0294 100644 --- a/pkg/usermem/bytes_io.go +++ b/pkg/usermem/bytes_io.go @@ -16,6 +16,7 @@ package usermem import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -30,7 +31,7 @@ type BytesIO struct { } // CopyOut implements IO.CopyOut. -func (b *BytesIO) CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) { +func (b *BytesIO) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts IOOpts) (int, error) { rngN, rngErr := b.rangeCheck(addr, len(src)) if rngN == 0 { return 0, rngErr @@ -39,7 +40,7 @@ func (b *BytesIO) CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpt } // CopyIn implements IO.CopyIn. -func (b *BytesIO) CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) { +func (b *BytesIO) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts IOOpts) (int, error) { rngN, rngErr := b.rangeCheck(addr, len(dst)) if rngN == 0 { return 0, rngErr @@ -48,7 +49,7 @@ func (b *BytesIO) CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts } // ZeroOut implements IO.ZeroOut. -func (b *BytesIO) ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) { +func (b *BytesIO) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error) { if toZero > int64(maxInt) { return 0, syserror.EINVAL } @@ -64,7 +65,7 @@ func (b *BytesIO) ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOO } // CopyOutFrom implements IO.CopyOutFrom. 
-func (b *BytesIO) CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) { +func (b *BytesIO) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) { dsts, rngErr := b.blocksFromAddrRanges(ars) n, err := src.ReadToBlocks(dsts) if err != nil { @@ -74,7 +75,7 @@ func (b *BytesIO) CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem } // CopyInTo implements IO.CopyInTo. -func (b *BytesIO) CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) { +func (b *BytesIO) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) { srcs, rngErr := b.blocksFromAddrRanges(ars) n, err := dst.WriteFromBlocks(srcs) if err != nil { @@ -83,14 +84,14 @@ func (b *BytesIO) CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Wr return int64(n), rngErr } -func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) { +func (b *BytesIO) rangeCheck(addr hostarch.Addr, length int) (int, error) { if length == 0 { return 0, nil } if length < 0 { return 0, syserror.EINVAL } - max := Addr(len(b.Bytes)) + max := hostarch.Addr(len(b.Bytes)) if addr >= max { return 0, syserror.EFAULT } @@ -101,7 +102,7 @@ func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) { return length, nil } -func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) { +func (b *BytesIO) blocksFromAddrRanges(ars hostarch.AddrRangeSeq) (safemem.BlockSeq, error) { switch ars.NumRanges() { case 0: return safemem.BlockSeq{}, nil @@ -124,7 +125,7 @@ func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, erro } } -func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) { +func (b *BytesIO) blockFromAddrRange(ar hostarch.AddrRange) (safemem.Block, error) { n, err := b.rangeCheck(ar.Start, int(ar.Length())) if n == 0 { return safemem.Block{}, err @@ -136,6 +137,6 @@ func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) { func BytesIOSequence(buf []byte) IOSequence { return IOSequence{ IO: &BytesIO{buf}, - Addrs: AddrRangeSeqOf(AddrRange{0, Addr(len(buf))}), + Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(len(buf))}), } } diff --git a/pkg/usermem/bytes_io_unsafe.go b/pkg/usermem/bytes_io_unsafe.go index 20de5037d..dcd5c81d1 100644 --- a/pkg/usermem/bytes_io_unsafe.go +++ b/pkg/usermem/bytes_io_unsafe.go @@ -20,10 +20,11 @@ import ( "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" ) // SwapUint32 implements IO.SwapUint32. -func (b *BytesIO) SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) { +func (b *BytesIO) SwapUint32(ctx context.Context, addr hostarch.Addr, new uint32, opts IOOpts) (uint32, error) { if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil { return 0, rngErr } @@ -31,7 +32,7 @@ func (b *BytesIO) SwapUint32(ctx context.Context, addr Addr, new uint32, opts IO } // CompareAndSwapUint32 implements IO.CompareAndSwapUint32. 
-func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) { +func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr hostarch.Addr, old, new uint32, opts IOOpts) (uint32, error) { if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil { return 0, rngErr } @@ -39,7 +40,7 @@ func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr Addr, old, new } // LoadUint32 implements IO.LoadUint32. -func (b *BytesIO) LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) { +func (b *BytesIO) LoadUint32(ctx context.Context, addr hostarch.Addr, opts IOOpts) (uint32, error) { if _, err := b.rangeCheck(addr, 4); err != nil { return 0, err } diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index dc2571154..0d6d25e50 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -25,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" + + "gvisor.dev/gvisor/pkg/hostarch" ) // IO provides access to the contents of a virtual memory space. @@ -37,7 +39,7 @@ type IO interface { // any following locks in the lock order. // // Postconditions: CopyOut does not retain src. - CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) + CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts IOOpts) (int, error) // CopyIn copies len(dst) bytes from the memory mapped at addr to dst. // It returns the number of bytes copied. If the number of bytes copied is @@ -47,7 +49,7 @@ type IO interface { // any following locks in the lock order. // // Postconditions: CopyIn does not retain dst. - CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) + CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts IOOpts) (int, error) // ZeroOut sets toZero bytes to 0, starting at addr. It returns the number // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a @@ -57,7 +59,7 @@ type IO interface { // * The caller must not hold mm.MemoryManager.mappingMu or any // following locks in the lock order. // * toZero >= 0. - ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) + ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error) // CopyOutFrom copies ars.NumBytes() bytes from src to the memory mapped at // ars. It returns the number of bytes copied, which may be less than the @@ -72,7 +74,7 @@ type IO interface { // following locks in the lock order. // * src.ReadToBlocks must not block on mm.MemoryManager.activeMu or // any preceding locks in the lock order. - CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) + CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) // CopyInTo copies ars.NumBytes() bytes from the memory mapped at ars to // dst. It returns the number of bytes copied. CopyInTo may return a @@ -86,7 +88,7 @@ type IO interface { // following locks in the lock order. // * dst.WriteFromBlocks must not block on mm.MemoryManager.activeMu or // any preceding locks in the lock order. 
- CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) + CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) // TODO(jamieliu): The requirement that CopyOutFrom/CopyInTo call src/dst // at most once, which is unnecessary in most cases, forces implementations @@ -101,7 +103,7 @@ type IO interface { // * The caller must not hold mm.MemoryManager.mappingMu or any // following locks in the lock order. // * addr must be aligned to a 4-byte boundary. - SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) + SwapUint32(ctx context.Context, addr hostarch.Addr, new uint32, opts IOOpts) (uint32, error) // CompareAndSwapUint32 atomically compares the uint32 value at addr to // old; if they are equal, the value in memory is replaced by new. In @@ -111,7 +113,7 @@ type IO interface { // * The caller must not hold mm.MemoryManager.mappingMu or any // following locks in the lock order. // * addr must be aligned to a 4-byte boundary. - CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) + CompareAndSwapUint32(ctx context.Context, addr hostarch.Addr, old, new uint32, opts IOOpts) (uint32, error) // LoadUint32 atomically loads the uint32 value at addr and returns it. // @@ -119,7 +121,7 @@ type IO interface { // * The caller must not hold mm.MemoryManager.mappingMu or any // following locks in the lock order. // * addr must be aligned to a 4-byte boundary. - LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) + LoadUint32(ctx context.Context, addr hostarch.Addr, opts IOOpts) (uint32, error) } // IOOpts contains options applicable to all IO methods. @@ -142,7 +144,7 @@ type IOOpts struct { type IOReadWriter struct { Ctx context.Context IO IO - Addr Addr + Addr hostarch.Addr Opts IOOpts } @@ -159,7 +161,7 @@ func (rw *IOReadWriter) Read(dst []byte) (int, error) { rw.Addr = end } else { // Disallow wraparound. - rw.Addr = ^Addr(0) + rw.Addr = ^hostarch.Addr(0) if err != nil { err = syserror.EFAULT } @@ -175,7 +177,7 @@ func (rw *IOReadWriter) Write(src []byte) (int, error) { rw.Addr = end } else { // Disallow wraparound. - rw.Addr = ^Addr(0) + rw.Addr = ^hostarch.Addr(0) if err != nil { err = syserror.EFAULT } @@ -197,7 +199,7 @@ const ( // // Preconditions: Same as IO.CopyFromUser, plus: // * maxlen >= 0. -func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) { +func CopyStringIn(ctx context.Context, uio IO, addr hostarch.Addr, maxlen int, opts IOOpts) (string, error) { initLen := maxlen if initLen > copyStringMaxInitBufLen { initLen = copyStringMaxInitBufLen @@ -251,12 +253,12 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt // the maximum, it returns a non-nil error explaining why. // // Preconditions: Same as IO.CopyOut. 
-func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts IOOpts) (int, error) { +func CopyOutVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, src []byte, opts IOOpts) (int, error) { var done int for !ars.IsEmpty() && done < len(src) { ar := ars.Head() cplen := len(src) - done - if Addr(cplen) >= ar.Length() { + if hostarch.Addr(cplen) >= ar.Length() { cplen = int(ar.Length()) } n, err := uio.CopyOut(ctx, ar.Start, src[done:done+cplen], opts) @@ -275,12 +277,12 @@ func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts // maximum, it returns a non-nil error explaining why. // // Preconditions: Same as IO.CopyIn. -func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts IOOpts) (int, error) { +func CopyInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, dst []byte, opts IOOpts) (int, error) { var done int for !ars.IsEmpty() && done < len(dst) { ar := ars.Head() cplen := len(dst) - done - if Addr(cplen) >= ar.Length() { + if hostarch.Addr(cplen) >= ar.Length() { cplen = int(ar.Length()) } n, err := uio.CopyIn(ctx, ar.Start, dst[done:done+cplen], opts) @@ -299,12 +301,12 @@ func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts I // maximum, it returns a non-nil error explaining why. // // Preconditions: Same as IO.ZeroOut. -func ZeroOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) { +func ZeroOutVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) { var done int64 for !ars.IsEmpty() && done < toZero { ar := ars.Head() cplen := toZero - done - if Addr(cplen) >= ar.Length() { + if hostarch.Addr(cplen) >= ar.Length() { cplen = int64(ar.Length()) } n, err := uio.ZeroOut(ctx, ar.Start, cplen, opts) @@ -352,7 +354,7 @@ func isASCIIWhitespace(b byte) bool { // - CopyInt32StringsInVec returns EINVAL if ars.NumBytes() == 0. // // Preconditions: Same as CopyInVec. -func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) { +func CopyInt32StringsInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) { if len(dsts) == 0 { return 0, nil } @@ -403,7 +405,7 @@ func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts [ // CopyInt32StringInVec is equivalent to CopyInt32StringsInVec, but copies at // most one int32. -func CopyInt32StringInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst *int32, opts IOOpts) (int64, error) { +func CopyInt32StringInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, dst *int32, opts IOOpts) (int64, error) { dsts := [1]int32{*dst} n, err := CopyInt32StringsInVec(ctx, uio, ars, dsts[:], opts) *dst = dsts[0] @@ -413,7 +415,7 @@ func CopyInt32StringInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst *in // IOSequence holds arguments to IO methods. type IOSequence struct { IO IO - Addrs AddrRangeSeq + Addrs hostarch.AddrRangeSeq Opts IOOpts } @@ -444,28 +446,28 @@ func (s IOSequence) NumBytes() int64 { // DropFirst returns a copy of s with s.Addrs.DropFirst(n). // -// Preconditions: Same as AddrRangeSeq.DropFirst. +// Preconditions: Same as hostarch.AddrRangeSeq.DropFirst. func (s IOSequence) DropFirst(n int) IOSequence { return IOSequence{s.IO, s.Addrs.DropFirst(n), s.Opts} } // DropFirst64 returns a copy of s with s.Addrs.DropFirst64(n). // -// Preconditions: Same as AddrRangeSeq.DropFirst64. 
+// Preconditions: Same as hostarch.AddrRangeSeq.DropFirst64. func (s IOSequence) DropFirst64(n int64) IOSequence { return IOSequence{s.IO, s.Addrs.DropFirst64(n), s.Opts} } // TakeFirst returns a copy of s with s.Addrs.TakeFirst(n). // -// Preconditions: Same as AddrRangeSeq.TakeFirst. +// Preconditions: Same as hostarch.AddrRangeSeq.TakeFirst. func (s IOSequence) TakeFirst(n int) IOSequence { return IOSequence{s.IO, s.Addrs.TakeFirst(n), s.Opts} } // TakeFirst64 returns a copy of s with s.Addrs.TakeFirst64(n). // -// Preconditions: Same as AddrRangeSeq.TakeFirst64. +// Preconditions: Same as hostarch.AddrRangeSeq.TakeFirst64. func (s IOSequence) TakeFirst64(n int64) IOSequence { return IOSequence{s.IO, s.Addrs.TakeFirst64(n), s.Opts} } diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go index da60b0cc7..9b697b593 100644 --- a/pkg/usermem/usermem_test.go +++ b/pkg/usermem/usermem_test.go @@ -22,6 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -106,7 +107,7 @@ func TestBytesIOZeroOutFailure(t *testing.T) { func TestBytesIOCopyOutFromSuccess(t *testing.T) { b := newBytesIOString("ABCDEFGH") - n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ + n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ {Start: 4, End: 7}, {Start: 1, End: 4}, }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{}) @@ -120,7 +121,7 @@ func TestBytesIOCopyOutFromSuccess(t *testing.T) { func TestBytesIOCopyOutFromFailure(t *testing.T) { b := newBytesIOString("ABCDE") - n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ + n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ {Start: 1, End: 4}, {Start: 4, End: 7}, }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{}) @@ -135,7 +136,7 @@ func TestBytesIOCopyOutFromFailure(t *testing.T) { func TestBytesIOCopyInToSuccess(t *testing.T) { b := newBytesIOString("AfoobarH") var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ + n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ {Start: 4, End: 7}, {Start: 1, End: 4}, }), safemem.FromIOWriter{&dst}, IOOpts{}) @@ -150,7 +151,7 @@ func TestBytesIOCopyInToSuccess(t *testing.T) { func TestBytesIOCopyInToFailure(t *testing.T) { b := newBytesIOString("Afoob") var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ + n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ {Start: 1, End: 4}, {Start: 4, End: 7}, }), safemem.FromIOWriter{&dst}, IOOpts{}) |
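The usermem changes above are mostly mechanical: Addr, AddrRange and AddrRangeSeq now come from the hostarch package, while logic such as BytesIO.rangeCheck is unchanged, so a request starting past the end of the backing slice faults and one running off the end is truncated to the bytes that exist. A standalone sketch of that check with a local Addr alias follows; the overflow handling is an assumption, since the middle of the original function is not shown in the hunk:

package main

import (
    "errors"
    "fmt"
)

// Addr stands in for hostarch.Addr in this sketch.
type Addr uint64

var (
    errInval = errors.New("EINVAL")
    errFault = errors.New("EFAULT")
)

// rangeCheck reports how many of the length requested bytes starting at addr
// fall inside a buffer of bufLen bytes: zero-length requests succeed,
// negative lengths are invalid, a start at or past the end faults, and a
// range running off the end is truncated and reported as a partial fault.
func rangeCheck(addr Addr, length, bufLen int) (int, error) {
    if length == 0 {
        return 0, nil
    }
    if length < 0 {
        return 0, errInval
    }
    max := Addr(bufLen)
    if addr >= max {
        return 0, errFault
    }
    end := addr + Addr(length)
    if end < addr || end > max { // overflow or past the end: truncate
        return int(max - addr), errFault
    }
    return length, nil
}

func main() {
    fmt.Println(rangeCheck(4, 3, 8)) // 3 <nil>
    fmt.Println(rangeCheck(6, 4, 8)) // 2 EFAULT
    fmt.Println(rangeCheck(9, 1, 8)) // 0 EFAULT
}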