diff options
Diffstat (limited to 'pkg')
41 files changed, 969 insertions, 498 deletions
diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go index 0f8788f3b..170c15705 100644 --- a/pkg/lisafs/client_file.go +++ b/pkg/lisafs/client_file.go @@ -15,6 +15,8 @@ package lisafs import ( + "fmt" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -121,41 +123,92 @@ func (f *ClientFD) Sync(ctx context.Context) error { return err } +// chunkify applies fn to buf in chunks based on chunkSize. +func chunkify(chunkSize uint64, buf []byte, fn func([]byte, uint64) (uint64, error)) (uint64, error) { + toProcess := uint64(len(buf)) + var ( + totalProcessed uint64 + curProcessed uint64 + off uint64 + err error + ) + for { + if totalProcessed == toProcess { + return totalProcessed, nil + } + + if totalProcessed+chunkSize > toProcess { + curProcessed, err = fn(buf[totalProcessed:], off) + } else { + curProcessed, err = fn(buf[totalProcessed:totalProcessed+chunkSize], off) + } + totalProcessed += curProcessed + off += curProcessed + + if err != nil { + return totalProcessed, err + } + + // Return partial result immediately. + if curProcessed < chunkSize { + return totalProcessed, nil + } + + // If we received more bytes than we ever requested, this is a problem. + if totalProcessed > toProcess { + panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", totalProcessed, toProcess)) + } + } +} + // Read makes the PRead RPC. func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) { - req := PReadReq{ - Offset: offset, - FD: f.fd, - Count: uint32(len(dst)), - } + var resp PReadResp + // maxDataReadSize represents the maximum amount of data we can read at once + // (maximum message size - metadata size present in resp). Uninitialized + // resp.SizeBytes() correctly returns the metadata size only (since the read + // buffer is empty). + maxDataReadSize := uint64(f.client.maxMessageSize) - uint64(resp.SizeBytes()) + return chunkify(maxDataReadSize, dst, func(buf []byte, curOff uint64) (uint64, error) { + req := PReadReq{ + Offset: offset + curOff, + FD: f.fd, + Count: uint32(len(buf)), + } - resp := PReadResp{ // This will be unmarshalled into. Already set Buf so that we don't need to // allocate a temporary buffer during unmarshalling. // PReadResp.UnmarshalBytes expects this to be set. - Buf: dst, - } - - ctx.UninterruptibleSleepStart(false) - err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) - ctx.UninterruptibleSleepFinish(false) - return uint64(resp.NumBytes), err + resp.Buf = buf + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return uint64(resp.NumBytes), err + }) } // Write makes the PWrite RPC. func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) { - req := PWriteReq{ - Offset: primitive.Uint64(offset), - FD: f.fd, - NumBytes: primitive.Uint32(len(src)), - Buf: src, - } + var req PWriteReq + // maxDataWriteSize represents the maximum amount of data we can write at + // once (maximum message size - metadata size present in req). Uninitialized + // req.SizeBytes() correctly returns the metadata size only (since the write + // buffer is empty). + maxDataWriteSize := uint64(f.client.maxMessageSize) - uint64(req.SizeBytes()) + return chunkify(maxDataWriteSize, src, func(buf []byte, curOff uint64) (uint64, error) { + req = PWriteReq{ + Offset: primitive.Uint64(offset + curOff), + FD: f.fd, + NumBytes: primitive.Uint32(len(buf)), + Buf: buf, + } - var resp PWriteResp - ctx.UninterruptibleSleepStart(false) - err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) - ctx.UninterruptibleSleepFinish(false) - return resp.Count, err + var resp PWriteResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Count, err + }) } // MkdirAt makes the MkdirAt RPC. diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go index 476ff76a5..5fc7c364d 100644 --- a/pkg/lisafs/testsuite/testsuite.go +++ b/pkg/lisafs/testsuite/testsuite.go @@ -330,8 +330,8 @@ func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root li defer closeFD(ctx, t, fd) defer unix.Close(hostFD) - // Test Read/Write RPCs. - data := make([]byte, 100) + // Test Read/Write RPCs with 2MB of data to test IO in chunks. + data := make([]byte, 1<<21) rand.Read(data) if err := writeFD(ctx, t, fd, 0, data); err != nil { t.Fatalf("write failed: %v", err) diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 9dac53c80..3f17fba49 100644 --- a/pkg/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -322,12 +322,3 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc func (p *PageTables) MarkReadOnlyShared() { p.readOnlyShared = true } - -// PrefaultRootTable touches the root table page to be sure that its physical -// pages are mapped. -// -//go:nosplit -//go:noinline -func (p *PageTables) PrefaultRootTable() PTE { - return p.root[0] -} diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index f322d2747..7fcb2d26b 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -84,6 +84,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + k := kernel.KernelFromContext(ctx) + fsDirChildren := make(map[string]kernfs.Inode) + // Create an empty directory to serve as the mount point for cgroupfs when + // cgroups are available. This emulates Linux behaviour, see + // kernel/cgroup.c:cgroup_init(). Note that in Linux, userspace (typically + // the init process) is ultimately responsible for actually mounting + // cgroupfs, but the kernel creates the mountpoint. For the sentry, the + // launcher mounts cgroupfs. + if k.CgroupRegistry() != nil { + fsDirChildren["cgroup"] = fs.newDir(ctx, creds, defaultSysDirMode, nil) + } + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), @@ -97,7 +109,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt }), }), "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), - "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, fsDirChildren), "kernel": kernelDir(ctx, fs, creds), "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 0a0d914cc..0c46a3a13 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -87,3 +87,17 @@ func TestSysRootContainsExpectedEntries(t *testing.T) { "power": linux.DT_DIR, }) } + +func TestCgroupMountpointExists(t *testing.T) { + // Note: The mountpoint is only created if cgroups are available. This is + // the VFS2 implementation of sysfs and the test runs with VFS2 enabled, so + // we expect to see the mount point unconditionally. + s := newTestSystem(t) + defer s.Destroy() + pop := s.PathOpAtRoot("/fs") + s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ + "cgroup": linux.DT_DIR, + }) + pop = s.PathOpAtRoot("/fs/cgroup") + s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ /*empty*/ }) +} diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 5bba9de0b..2363cec5f 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -1,13 +1,26 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) +go_template_instance( + name = "atomicptr_netns", + out = "atomicptr_netns_unsafe.go", + package = "inet", + prefix = "Namespace", + template = "//pkg/sync/atomicptr:generic_atomicptr", + types = { + "Value": "Namespace", + }, +) + go_library( name = "inet", srcs = [ + "atomicptr_netns_unsafe.go", "context.go", "inet.go", "namespace.go", diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 9a95bf44c..b0004482c 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -511,9 +511,7 @@ type Task struct { numaNodeMask uint64 // netns is the task's network namespace. netns is never nil. - // - // netns is protected by mu. - netns *inet.Namespace + netns inet.NamespaceAtomicPtr // If rseqPreempted is true, before the next call to p.Switch(), // interrupt rseq critical regions as defined by rseqAddr and diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 26a981f36..a6d8fb163 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -444,7 +444,7 @@ func (t *Task) Unshare(flags int32) error { t.mu.Unlock() return linuxerr.EPERM } - t.netns = inet.NewNamespace(t.netns) + t.netns.Store(inet.NewNamespace(t.netns.Load())) } if flags&linux.CLONE_NEWUTS != 0 { if !haveCapSysAdmin { diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go index f7711232c..e31e2b2e8 100644 --- a/pkg/sentry/kernel/task_net.go +++ b/pkg/sentry/kernel/task_net.go @@ -20,9 +20,7 @@ import ( // IsNetworkNamespaced returns true if t is in a non-root network namespace. func (t *Task) IsNetworkNamespaced() bool { - t.mu.Lock() - defer t.mu.Unlock() - return !t.netns.IsRoot() + return !t.netns.Load().IsRoot() } // NetworkContext returns the network stack used by the task. NetworkContext @@ -31,14 +29,10 @@ func (t *Task) IsNetworkNamespaced() bool { // TODO(gvisor.dev/issue/1833): Migrate callers of this method to // NetworkNamespace(). func (t *Task) NetworkContext() inet.Stack { - t.mu.Lock() - defer t.mu.Unlock() - return t.netns.Stack() + return t.netns.Load().Stack() } // NetworkNamespace returns the network namespace observed by the task. func (t *Task) NetworkNamespace() *inet.Namespace { - t.mu.Lock() - defer t.mu.Unlock() - return t.netns + return t.netns.Load() } diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 217c6f531..4919dea7c 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -140,7 +140,6 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { allowedCPUMask: cfg.AllowedCPUMask.Copy(), ioUsage: &usage.IO{}, niceness: cfg.Niceness, - netns: cfg.NetworkNamespace, utsns: cfg.UTSNamespace, ipcns: cfg.IPCNamespace, abstractSockets: cfg.AbstractSocketNamespace, @@ -152,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { containerID: cfg.ContainerID, cgroups: make(map[Cgroup]struct{}), } + t.netns.Store(cfg.NetworkNamespace) t.creds.Store(cfg.Credentials) t.endStopCond.L = &t.tg.signalHandlers.mu t.ptraceTracer.Store((*Task)(nil)) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 8a490b3de..a26f54269 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,13 +1,26 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "atomicptr_machine", + out = "atomicptr_machine_unsafe.go", + package = "kvm", + prefix = "machine", + template = "//pkg/sync/atomicptr:generic_atomicptr", + types = { + "Value": "machine", + }, +) + go_library( name = "kvm", srcs = [ "address_space.go", "address_space_amd64.go", "address_space_arm64.go", + "atomicptr_machine_unsafe.go", "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", @@ -69,10 +82,17 @@ go_test( "kvm_amd64_test.go", "kvm_amd64_test.s", "kvm_arm64_test.go", + "kvm_safecopy_test.go", "kvm_test.go", "virtual_map_test.go", ], library = ":kvm", + # FIXME(gvisor.dev/issue/3374): Not working with all build systems. + nogo = False, + # cgo has to be disabled. We have seen libc that blocks all signals and + # calls mmap from pthread_create, but we use SIGSYS to trap mmap system + # calls. + pure = True, tags = [ "manual", "nogotsan", @@ -81,8 +101,10 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/hostarch", + "//pkg/memutil", "//pkg/ring0", "//pkg/ring0/pagetables", + "//pkg/safecopy", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", "//pkg/sentry/platform", diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index bb9967b9f..826997e77 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -61,6 +61,9 @@ var ( // This is called by bluepillHandler. savedHandler uintptr + // savedSigsysHandler is a pointer to the previos handler of the SIGSYS signals. + savedSigsysHandler uintptr + // dieTrampolineAddr is the address of dieTrampoline. dieTrampolineAddr uintptr ) diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index c2a1dca11..5d8358f64 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -32,6 +32,8 @@ // This is checked as the source of the fault. #define CLI $0xfa +#define SYS_MMAP 9 + // See bluepill.go. TEXT ·bluepill(SB),NOSPLIT,$0 begin: @@ -95,6 +97,31 @@ TEXT ·addrOfSighandler(SB), $0-8 MOVQ AX, ret+0(FP) RET +TEXT ·sigsysHandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel. + MOVQ $1, CX + CMPL CX, 0x8(SI) + JNE fallback + + MOVL CONTEXT_RAX(DX), CX + CMPL CX, $SYS_MMAP + JNE fallback + PUSHQ DX // First argument (context). + CALL ·seccompMmapHandler(SB) // Call the handler. + POPQ DX // Discard the argument. + RET +fallback: + // Jump to the previous signal handler. + XORQ CX, CX + MOVQ ·savedSigsysHandler(SB), AX + JMP AX + +// func addrOfSighandler() uintptr +TEXT ·addrOfSigsysHandler(SB), $0-8 + MOVQ $·sigsysHandler(SB), AX + MOVQ AX, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 PUSHQ BX // First argument (vCPU). diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 308f2a951..9690e3772 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -29,9 +29,12 @@ // Only limited use of the context is done in the assembly stub below, most is // done in the Go handlers. #define SIGINFO_SIGNO 0x0 +#define SIGINFO_CODE 0x8 #define CONTEXT_PC 0x1B8 #define CONTEXT_R0 0xB8 +#define SYS_MMAP 222 + // getTLS returns the value of TPIDR_EL0 register. TEXT ·getTLS(SB),NOSPLIT,$0-8 MRS TPIDR_EL0, R1 @@ -98,6 +101,37 @@ TEXT ·addrOfSighandler(SB), $0-8 MOVD R0, ret+0(FP) RET +// The arguments are the following: +// +// R0 - The signal number. +// R1 - Pointer to siginfo_t structure. +// R2 - Pointer to ucontext structure. +// +TEXT ·sigsysHandler(SB),NOSPLIT,$0 + // si_code should be SYS_SECCOMP. + MOVD SIGINFO_CODE(R1), R7 + CMPW $1, R7 + BNE fallback + + CMPW $SYS_MMAP, R8 + BNE fallback + + MOVD R2, 8(RSP) + BL ·seccompMmapHandler(SB) // Call the handler. + + RET + +fallback: + // Jump to the previous signal handler. + MOVD ·savedHandler(SB), R7 + B (R7) + +// func addrOfSighandler() uintptr +TEXT ·addrOfSigsysHandler(SB), $0-8 + MOVD $·sigsysHandler(SB), R0 + MOVD R0, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 // R0: Fake the old PC as caller diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 0f0c1e73b..e38ca05c0 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -193,36 +193,8 @@ func bluepillHandler(context unsafe.Pointer) { return } - // Increment the fault count. - atomic.AddUint32(&c.faults, 1) - - // For MMIO, the physical address is the first data item. - physical = uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) - if !ok { - c.die(bluepillArchContext(context), "invalid physical address") - return - } - - // We now need to fill in the data appropriately. KVM - // expects us to provide the result of the given MMIO - // operation in the runData struct. This is safe - // because, if a fault occurs here, the same fault - // would have occurred in guest mode. The kernel should - // not create invalid page table mappings. - data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1])) - length := (uintptr)((uint32)(c.runData.data[2])) - write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0 - for i := uintptr(0); i < length; i++ { - b := bytePtr(uintptr(virtual) + i) - if write { - // Write to the given address. - *b = data[i] - } else { - // Read from the given address. - data[i] = *b - } - } + c.die(bluepillArchContext(context), "exit_mmio") + return case _KVM_EXIT_IRQ_WINDOW_OPEN: bluepillStopGuest(c) case _KVM_EXIT_SHUTDOWN: diff --git a/pkg/sentry/platform/kvm/kvm_safecopy_test.go b/pkg/sentry/platform/kvm/kvm_safecopy_test.go new file mode 100644 index 000000000..9a87c9e6f --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_safecopy_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// FIXME(gvisor.dev/issue//6629): These tests don't pass on ARM64. +// +//go:build amd64 +// +build amd64 + +package kvm + +import ( + "fmt" + "os" + "testing" + "unsafe" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/safecopy" +) + +func testSafecopy(t *testing.T, mapSize uintptr, fileSize uintptr, testFunc func(t *testing.T, c *vCPU, addr uintptr)) { + memfd, err := memutil.CreateMemFD(fmt.Sprintf("kvm_test_%d", os.Getpid()), 0) + if err != nil { + t.Errorf("error creating memfd: %v", err) + } + + memfile := os.NewFile(uintptr(memfd), "kvm_test") + memfile.Truncate(int64(fileSize)) + kvmTest(t, nil, func(c *vCPU) bool { + const n = 10 + mappings := make([]uintptr, n) + defer func() { + for i := 0; i < n && mappings[i] != 0; i++ { + unix.RawSyscall( + unix.SYS_MUNMAP, + mappings[i], mapSize, 0) + } + }() + for i := 0; i < n; i++ { + addr, _, errno := unix.RawSyscall6( + unix.SYS_MMAP, + 0, + mapSize, + unix.PROT_READ|unix.PROT_WRITE, + unix.MAP_SHARED|unix.MAP_FILE, + uintptr(memfile.Fd()), + 0) + if errno != 0 { + t.Errorf("error mapping file: %v", errno) + } + mappings[i] = addr + testFunc(t, c, addr) + } + return false + }) +} + +func TestSafecopySigbus(t *testing.T) { + mapSize := uintptr(faultBlockSize) + fileSize := mapSize - hostarch.PageSize + buf := make([]byte, hostarch.PageSize) + testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) { + want := safecopy.BusError{addr + fileSize} + bluepill(c) + _, err := safecopy.CopyIn(buf, unsafe.Pointer(addr+fileSize)) + if err != want { + t.Errorf("expected error: got %v, want %v", err, want) + } + }) +} + +func TestSafecopy(t *testing.T) { + mapSize := uintptr(faultBlockSize) + fileSize := mapSize + testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) { + want := uint32(0x12345678) + bluepill(c) + _, err := safecopy.SwapUint32(unsafe.Pointer(addr+fileSize-8), want) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + bluepill(c) + val, err := safecopy.LoadUint32(unsafe.Pointer(addr + fileSize - 8)) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if val != want { + t.Errorf("incorrect value: got %x, want %x", val, want) + } + }) +} diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index d67563958..dcf34015d 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -17,15 +17,19 @@ package kvm import ( "fmt" "runtime" + gosync "sync" "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/seccomp" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" ) @@ -35,6 +39,9 @@ type machine struct { // fd is the vm fd. fd int + // machinePoolIndex is the index in the machinePool array. + machinePoolIndex uint32 + // nextSlot is the next slot for setMemoryRegion. // // This must be accessed atomically. If nextSlot is ^uint32(0), then @@ -192,6 +199,10 @@ func (m *machine) newVCPU() *vCPU { return c // Done. } +// readOnlyGuestRegions contains regions that have to be mapped read-only into +// the guest physical address space. Right now, it is used on arm64 only. +var readOnlyGuestRegions []region + // newMachine returns a new VM context. func newMachine(vm int) (*machine, error) { // Create the machine. @@ -227,6 +238,10 @@ func newMachine(vm int) (*machine, error) { m.upperSharedPageTables.MarkReadOnlyShared() m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress) + // Install seccomp rules to trap runtime mmap system calls. They will + // be handled by seccompMmapHandler. + seccompMmapRules(m) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -241,32 +256,11 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) - var physicalRegionsReadOnly []physicalRegion - var physicalRegionsAvailable []physicalRegion - - physicalRegionsReadOnly = rdonlyRegionsForSetMem() - physicalRegionsAvailable = availableRegionsForSetMem() - - // Map all read-only regions. - for _, r := range physicalRegionsReadOnly { - m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) - } - // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to // ensure successful vCPU entry. - applyVirtualRegions(func(vr virtualRegion) { - if excludeVirtualRegion(vr) { - return // skip region. - } - - for _, r := range physicalRegionsReadOnly { - if vr.virtual == r.virtual { - return - } - } - + mapRegion := func(vr region, flags uint32) { for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -280,9 +274,32 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) + m.mapPhysical(physical, length, physicalRegions, flags) virtual += length } + } + + for _, vr := range readOnlyGuestRegions { + mapRegion(vr, _KVM_MEM_READONLY) + } + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + for _, r := range readOnlyGuestRegions { + if vr.virtual == r.virtual { + return + } + } + // Take into account that the stack can grow down. + if vr.filename == "[stack]" { + vr.virtual -= 1 << 20 + vr.length += 1 << 20 + } + + mapRegion(vr.region, 0) + }) // Initialize architecture state. @@ -352,6 +369,10 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg func (m *machine) Destroy() { runtime.SetFinalizer(m, nil) + machinePoolMu.Lock() + machinePool[m.machinePoolIndex].Store(nil) + machinePoolMu.Unlock() + // Destroy vCPUs. for _, c := range m.vCPUsByID { if c == nil { @@ -683,3 +704,72 @@ func (c *vCPU) setSystemTimeLegacy() error { } } } + +const machinePoolSize = 16 + +// machinePool is enumerated from the seccompMmapHandler signal handler +var ( + machinePool [machinePoolSize]machineAtomicPtr + machinePoolLen uint32 + machinePoolMu sync.Mutex + seccompMmapRulesOnce gosync.Once +) + +func sigsysHandler() +func addrOfSigsysHandler() uintptr + +// seccompMmapRules adds seccomp rules to trap mmap system calls that will be +// handled in seccompMmapHandler. +func seccompMmapRules(m *machine) { + seccompMmapRulesOnce.Do(func() { + // Install the handler. + if err := safecopy.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) + } + rules := []seccomp.RuleSet{} + rules = append(rules, []seccomp.RuleSet{ + // Trap mmap system calls and handle them in sigsysGoHandler + { + Rules: seccomp.SyscallRules{ + unix.SYS_MMAP: { + { + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + /* MAP_DENYWRITE is ignored and used only for filtering. */ + seccomp.MaskedEqual(unix.MAP_DENYWRITE, 0), + }, + }, + }, + Action: linux.SECCOMP_RET_TRAP, + }, + }...) + instrs, err := seccomp.BuildProgram(rules, linux.SECCOMP_RET_ALLOW, linux.SECCOMP_RET_ALLOW) + if err != nil { + panic(fmt.Sprintf("failed to build rules: %v", err)) + } + // Perform the actual installation. + if err := seccomp.SetFilter(instrs); err != nil { + panic(fmt.Sprintf("failed to set filter: %v", err)) + } + }) + + machinePoolMu.Lock() + n := atomic.LoadUint32(&machinePoolLen) + i := uint32(0) + for ; i < n; i++ { + if machinePool[i].Load() == nil { + break + } + } + if i == n { + if i == machinePoolSize { + machinePoolMu.Unlock() + panic("machinePool is full") + } + atomic.AddUint32(&machinePoolLen, 1) + } + machinePool[i].Store(m) + m.machinePoolIndex = i + machinePoolMu.Unlock() +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index a96634381..ab1e036b7 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -309,22 +309,6 @@ func loadByte(ptr *byte) byte { return *ptr } -// prefaultFloatingPointState touches each page of the floating point state to -// be sure that its physical pages are mapped. -// -// Otherwise the kernel can trigger KVM_EXIT_MMIO and an instruction that -// triggered a fault will be emulated by the kvm kernel code, but it can't -// emulate instructions like xsave and xrstor. -// -//go:nosplit -func prefaultFloatingPointState(data *fpu.State) { - size := len(*data) - for i := 0; i < size; i += hostarch.PageSize { - loadByte(&(*data)[i]) - } - loadByte(&(*data)[size-1]) -} - // SwitchToUser unpacks architectural-details. func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. @@ -355,11 +339,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) // allocations occur. entersyscall() bluepill(c) - // The root table physical page has to be mapped to not fault in iret - // or sysret after switching into a user address space. sysret and - // iret are in the upper half that is global and already mapped. - switchOpts.PageTables.PrefaultRootTable() - prefaultFloatingPointState(switchOpts.FloatingPointState) vector = c.CPU.SwitchToUser(switchOpts) exitsyscall() @@ -522,3 +501,7 @@ func (m *machine) getNewVCPU() *vCPU { } return nil } + +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index de798bb2c..fbacea9ad 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -161,3 +161,15 @@ func (c *vCPU) getSystemRegisters(sregs *systemRegs) unix.Errno { } return 0 } + +//go:nosplit +func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) { + ctx := bluepillArchContext(context) + + // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters. + addr, _, e := unix.RawSyscall6(uintptr(ctx.Rax), uintptr(ctx.Rdi), uintptr(ctx.Rsi), + uintptr(ctx.Rdx), uintptr(ctx.R10)|unix.MAP_DENYWRITE, uintptr(ctx.R8), uintptr(ctx.R9)) + ctx.Rax = uint64(addr) + + return addr, uintptr(ctx.Rsi), e +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 7937a8481..08d98c479 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -110,18 +110,128 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { return phyRegions } +// archPhysicalRegions fills readOnlyGuestRegions and allocates separate +// physical regions form them. +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + if !vr.accessType.Write { + readOnlyGuestRegions = append(readOnlyGuestRegions, vr.region) + } + }) + + rdRegions := readOnlyGuestRegions[:] + + // Add an unreachable region. + rdRegions = append(rdRegions, region{ + virtual: 0xffffffffffffffff, + length: 0, + }) + + var regions []physicalRegion + addValidRegion := func(r *physicalRegion, virtual, length uintptr) { + if length == 0 { + return + } + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), + }) + } + i := 0 + for _, pr := range physicalRegions { + start := pr.virtual + end := pr.virtual + pr.length + for start < end { + rdRegion := rdRegions[i] + rdStart := rdRegion.virtual + rdEnd := rdRegion.virtual + rdRegion.length + if rdEnd <= start { + i++ + continue + } + if rdStart > start { + newEnd := rdStart + if end < rdStart { + newEnd = end + } + addValidRegion(&pr, start, newEnd-start) + start = rdStart + continue + } + if rdEnd < end { + addValidRegion(&pr, start, rdEnd-start) + start = rdEnd + continue + } + addValidRegion(&pr, start, end-start) + start = end + } + } + + return regions +} + // Get all available physicalRegions. -func availableRegionsForSetMem() (phyRegions []physicalRegion) { - var excludeRegions []region +func availableRegionsForSetMem() []physicalRegion { + var excludedRegions []region applyVirtualRegions(func(vr virtualRegion) { if !vr.accessType.Write { - excludeRegions = append(excludeRegions, vr.region) + excludedRegions = append(excludedRegions, vr.region) } }) - phyRegions = computePhysicalRegions(excludeRegions) + // Add an unreachable region. + excludedRegions = append(excludedRegions, region{ + virtual: 0xffffffffffffffff, + length: 0, + }) - return phyRegions + var regions []physicalRegion + addValidRegion := func(r *physicalRegion, virtual, length uintptr) { + if length == 0 { + return + } + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), + }) + } + i := 0 + for _, pr := range physicalRegions { + start := pr.virtual + end := pr.virtual + pr.length + for start < end { + er := excludedRegions[i] + excludeEnd := er.virtual + er.length + excludeStart := er.virtual + if excludeEnd < start { + i++ + continue + } + if excludeStart < start { + start = excludeEnd + i++ + continue + } + rend := excludeStart + if rend > end { + rend = end + } + addValidRegion(&pr, start, rend-start) + start = excludeEnd + } + } + + return regions } // nonCanonical generates a canonical address return. diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 1a4a9ce7d..7e8e19dcb 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -333,3 +333,15 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) } } + +//go:nosplit +func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) { + ctx := bluepillArchContext(context) + + // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters. + addr, _, e := unix.RawSyscall6(uintptr(ctx.Regs[8]), uintptr(ctx.Regs[0]), uintptr(ctx.Regs[1]), + uintptr(ctx.Regs[2]), uintptr(ctx.Regs[3])|unix.MAP_DENYWRITE, uintptr(ctx.Regs[4]), uintptr(ctx.Regs[5])) + ctx.Regs[0] = uint64(addr) + + return addr, uintptr(ctx.Regs[1]), e +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index cc3a1253b..cf3a4e7c9 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -171,3 +171,46 @@ func (c *vCPU) setSignalMask() error { return nil } + +// seccompMmapHandler is a signal handler for runtime mmap system calls +// that are trapped by seccomp. +// +// It executes the mmap syscall with specified arguments and maps a new region +// to the guest. +// +//go:nosplit +func seccompMmapHandler(context unsafe.Pointer) { + addr, length, errno := seccompMmapSyscall(context) + if errno != 0 { + return + } + + for i := uint32(0); i < atomic.LoadUint32(&machinePoolLen); i++ { + m := machinePool[i].Load() + if m == nil { + continue + } + + // Map the new region to the guest. + vr := region{ + virtual: addr, + length: length, + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { + physical, length, ok := translateToPhysical(virtual) + if !ok { + // This must be an invalid region that was + // knocked out by creation of the physical map. + return + } + if virtual+length > vr.virtual+vr.length { + // Cap the length to the end of the area. + length = vr.virtual + vr.length - virtual + } + + // Ensure the physical range is mapped. + m.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) + virtual += length + } + } +} diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index d812e6c26..9864d1258 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -168,6 +168,9 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica } addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd) + // Do arch-specific actions on physical regions. + physicalRegions = archPhysicalRegions(physicalRegions) + // Dump our all physical regions. for _, r := range physicalRegions { log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)", diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go index 6d0ba8252..346a10043 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -30,8 +30,8 @@ import ( func TLSWorks() bool // SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *arch.Registers, fn func()) { - regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) +func SetTestTarget(regs *arch.Registers, fn uintptr) { + regs.Pc = uint64(fn) } // SetTouchTarget sets rax appropriately. diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 7348c29a5..42876245a 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -28,6 +28,11 @@ TEXT ·Getpid(SB),NOSPLIT,$0 SVC RET +TEXT ·AddrOfGetpid(SB),NOSPLIT,$0-8 + MOVD $·Getpid(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·Touch(SB),NOSPLIT,$0 start: MOVD 0(R8), R1 @@ -35,21 +40,41 @@ start: SVC B start +TEXT ·AddrOfTouch(SB),NOSPLIT,$0-8 + MOVD $·Touch(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·HaltLoop(SB),NOSPLIT,$0 start: HLT B start +TEXT ·AddOfHaltLoop(SB),NOSPLIT,$0-8 + MOVD $·HaltLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + // This function simulates a loop of syscall. TEXT ·SyscallLoop(SB),NOSPLIT,$0 start: SVC B start +TEXT ·AddrOfSyscallLoop(SB),NOSPLIT,$0-8 + MOVD $·SyscallLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·SpinLoop(SB),NOSPLIT,$0 start: B start +TEXT ·AddrOfSpinLoop(SB),NOSPLIT,$0-8 + MOVD $·SpinLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·TLSWorks(SB),NOSPLIT,$0-8 NO_LOCAL_POINTERS MOVD $0x6789, R5 @@ -125,6 +150,11 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 SVC RET // never reached +TEXT ·AddrOfTwiddleRegsSyscall(SB),NOSPLIT,$0-8 + MOVD $·TwiddleRegsSyscall(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 TWIDDLE_REGS() MSR R10, TPIDR_EL0 @@ -132,3 +162,8 @@ TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 // Branch to Register branches unconditionally to an address in <Rn>. JMP (R6) // <=> br x6, must fault RET // never reached + +TEXT ·AddrOfTwiddleRegsFault(SB),NOSPLIT,$0-8 + MOVD $·TwiddleRegsFault(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index ea56f39c1..0f6e576a9 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID { } // Action implements stack.Target.Action. -func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go index 82b6df18c..7b9c2a4db 100644 --- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go +++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go @@ -37,6 +37,8 @@ func (p *AtomicPtr) loadPtr(v *Value) { // Load returns the value set by the most recent Store. It returns nil if there // has been no previous call to Store. +// +//go:nosplit func (p *AtomicPtr) Load() *Value { return (*Value)(atomic.LoadPointer(&p.ptr)) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 25f5a52e3..dda473e48 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -426,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -466,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -549,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -576,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -717,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -744,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -841,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -940,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index dab99d00d..e2d2cf907 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1100,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -1183,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 3617b6dd0..74c9075b4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -264,12 +264,62 @@ const ( chainReturn ) -// Check runs pkt through the rules for hook. It returns true when the packet +// CheckPrerouting performs the prerouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool { + return it.check(Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) +} + +// CheckInput performs the input hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { + return it.check(Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) +} + +// CheckForward performs the forward hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool { + return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) +} + +// CheckOutput performs the output hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool { + return it.check(Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// CheckPostrouting performs the postrouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool { + return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// check runs pkt through the rules for hook. It returns true when the packet // should continue traversing the network stack and false when it should be // dropped. // -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -300,7 +350,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -311,7 +361,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v { case RuleAccept: continue case RuleDrop: @@ -375,19 +425,35 @@ func (it *IPTables) startReaper(interval time.Duration) { }() } -// CheckPackets runs pkts through the rules for hook and returns a map of packets that -// should not go forward. +// CheckOutputPackets performs the output hook on the packets. // -// Preconditions: -// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// * pkt.NetworkHeader is not nil. +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return it.checkPackets(Output, pkts, r, outNicName) +} + +// CheckPostroutingPackets performs the postrouting hook on the packets. +// +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return it.checkPackets(Postrouting, pkts, r, outNicName) +} + +// checkPackets runs pkts through the rules for hook and returns a map of +// packets that should not go forward. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { + if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -407,11 +473,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -428,7 +494,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -454,7 +520,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -477,7 +543,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) + return rule.Target.Action(pkt, &it.connections, hook, r, addressEP) } // OriginalDst returns the original destination of redirected connections. It diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index de5997e9e..e8806ebdb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,7 +79,7 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleReturn, 0 } @@ -97,7 +97,7 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -117,6 +117,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. + var address tcpip.Address switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { @@ -125,7 +126,8 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r address = header.IPv6Loopback } case Prerouting: - // No-op, as address is already set correctly. + // addressEP is expected to be set for the prerouting hook. + address = addressEP.MainAddress().Address default: panic("redirect target is supported only on output and prerouting hooks") } @@ -180,7 +182,7 @@ type SNATTarget struct { } // Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if st.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 66e5f22ac..976194124 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -352,5 +352,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c5e896295..d45a2c05c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1269,6 +1269,8 @@ type TransportProtocolNumber uint32 type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. +// +// +stateify savable type StatCounter struct { count atomicbitops.AlignedAtomicUint64 } @@ -1995,6 +1997,8 @@ type Stats struct { } // ReceiveErrors collects packet receive errors within transport endpoint. +// +// +stateify savable type ReceiveErrors struct { // ReceiveBufferOverflow is the number of received packets dropped // due to the receive buffer being full. @@ -2012,8 +2016,10 @@ type ReceiveErrors struct { ChecksumErrors StatCounter } -// SendErrors collects packet send errors within the transport layer for -// an endpoint. +// SendErrors collects packet send errors within the transport layer for an +// endpoint. +// +// +stateify savable type SendErrors struct { // SendToNetworkFailed is the number of packets failed to be written to // the network endpoint. @@ -2024,6 +2030,8 @@ type SendErrors struct { } // ReadErrors collects segment read errors from an endpoint read call. +// +// +stateify savable type ReadErrors struct { // ReadClosed is the number of received packet drops because the endpoint // was shutdown for read. @@ -2039,6 +2047,8 @@ type ReadErrors struct { } // WriteErrors collects packet write errors from an endpoint write call. +// +// +stateify savable type WriteErrors struct { // WriteClosed is the number of packet drops because the endpoint // was shutdown for write. @@ -2054,6 +2064,8 @@ type WriteErrors struct { } // TransportEndpointStats collects statistics about the endpoint. +// +// +stateify savable type TransportEndpointStats struct { // PacketsReceived is the number of successful packet receives. PacketsReceived StatCounter diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index bbc0e3ecc..4718ec4ec 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/waiter", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 1e519085d..bb0db9f70 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,7 @@ package icmp import ( + "fmt" "io" "time" @@ -24,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -35,15 +38,6 @@ type icmpPacket struct { receivedAt time.Time `state:".(int64)"` } -type endpointState int - -const ( - stateInitial endpointState = iota - stateBound - stateConnected - stateClosed -) - // endpoint represents an ICMP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -51,14 +45,17 @@ const ( // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -70,38 +67,23 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - state endpointState - route *stack.Route `state:"manual"` - ttl uint8 - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool + ident uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, - state: stateInitial, uniqueID: s.UniqueID(), } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) + ep.net.Init(s, netProto, transProto, &ep.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -128,35 +110,40 @@ func (e *endpoint) Abort() { // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { - e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { - case stateBound, stateConnected: - bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) - } + notify := func() bool { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + return false + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + info := e.net.Info() + info.ID.LocalPort = e.ident + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice())) + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } - // Close the receive list and drain it. - e.rcvMu.Lock() - e.rcvClosed = true - e.rcvBufSize = 0 - for !e.rcvList.Empty() { - p := e.rcvList.Front() - e.rcvList.Remove(p) - } - e.rcvMu.Unlock() + e.net.Shutdown() + e.net.Close() - if e.route != nil { - e.route.Release() - e.route = nil - } - - // Update the state. - e.state = stateClosed + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } - e.mu.Unlock() + return true + }() - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + if notify { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + } } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. @@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -214,13 +201,12 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // // Returns true for retry if preparation should be retried. // +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.state { - case stateInitial: - case stateConnected: +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - - case stateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { - return 0, err + return network.WriteContext{}, 0, err } if !retry { @@ -298,36 +272,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } } - route := e.route - if to != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := to.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - - dst, netProto, err := e.checkV4MappedLocked(*to) - if err != nil { - return 0, err - } - - // Find the endpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) - if err != nil { - return 0, err - } - defer r.Release() + ctx, err := e.net.AcquireContextForWrite(opts) + return ctx, e.ident, err +} - route = r +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + ctx, ident, err := e.prepareForWrite(opts) + if err != nil { + return 0, err } + defer ctx.Release() // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) @@ -335,17 +289,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return 0, &tcpip.ErrBadBuffer{} } - var err tcpip.Error - switch e.NetProto { + switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) + if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } case header.IPv6ProtocolNumber: - err = send6(route, e.ID.LocalPort, v, e.ttl) - } - - if err != nil { - return 0, err + if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } + default: + panic(fmt.Sprintf("unhandled network protocol = %d", netProto)) } return int64(len(v)), nil @@ -358,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool { return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt sets a socket option. -func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return nil +// SetSockOpt implements tcpip.Endpoint. +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + return e.net.SetSockOpt(opt) } -// SetSockOptInt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - } - return nil + return e.net.SetSockOptInt(opt, v) } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -388,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.rcvMu.Lock() - v := int(e.ttl) - e.rcvMu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +// GetSockOpt implements tcpip.Endpoint. +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { +func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength), }) - pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber @@ -427,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest - icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - pkt.Data().AppendView(data) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V4.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { +func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength), }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) @@ -469,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest pkt.Data().AppendView(data) dataRange := pkt.Data().AsRange() icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpv6, - Src: r.LocalAddress(), - Dst: r.RemoteAddress(), + Src: src, + Dst: dst, PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V6.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() + return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() tcpip.Error { return &tcpip.ErrNotSupported{} @@ -516,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - localPort := uint16(0) - switch e.state { - case stateInitial: - case stateBound, stateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.ident - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + nextID, err := e.registerWithStack(netProto, nextID) + if err != nil { + return err } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } - - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress(), - LocalPort: localPort, - RemoteAddress: r.RemoteAddress(), - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - id, err = e.registerWithStack(nicID, netProtos, id) + e.ident = nextID.LocalPort + return nil + }) if err != nil { - r.Release() return err } - e.ID = id - e.route = r - e.RegisterNICID = nicID - - e.state = stateConnected - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -586,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - e.shutdownFlags |= flags - if e.state != stateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } + + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } } if flags&tcpip.ShutdownRead != 0 { @@ -616,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) - return id, err + return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) + err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -645,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err := e.registerWithStack(boundNetProto, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, err = e.registerWithStack(addr.NIC, netProtos, id) + e.ident = id.LocalPort + return nil + }) if err != nil { return err } - e.ID = id - e.RegisterNICID = addr.NIC - - // Mark endpoint as bound. - e.state = stateBound - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -692,7 +571,7 @@ func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || - e.stack.IsSubnetBroadcast(nicID, e.NetProto, addr) + e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr) } // Bind binds the endpoint to a specific local address and port. @@ -705,15 +584,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - err := e.bindLocked(addr) - if err != nil { - return err - } - - e.BindNICID = addr.NIC - e.BindAddr = addr.Addr - - return nil + return e.bindLocked(addr) } // GetLocalAddress returns the address to which the endpoint is bound. @@ -721,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.ident + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -733,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { - return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} + if addr, connected := e.net.GetRemoteAddress(); connected { + return addr, nil } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness returns the current readiness of the endpoint. For example, if @@ -766,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. - switch e.NetProto { + switch e.net.NetProto() { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { @@ -840,9 +705,9 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + defer e.mu.RUnlock() + ret := e.net.Info() + ret.ID.LocalPort = e.ident return &ret } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index b8b839e4a..dfe453ff9 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,13 @@ package icmp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.thaw() + + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - if e.state != stateBound && e.state != stateConnected { - return - } - - var err tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + var err tcpip.Error + info := e.net.Info() + info.ID.LocalPort = e.ident + info.ID, err = e.registerWithStack(info.NetProto, info.ID) if err != nil { - panic(err) + panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err)) } - - e.ID.LocalAddress = e.route.LocalAddress() - } else if len(e.ID.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) - if err != nil { - panic(err) + e.ident = info.ID.LocalPort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD index b1edce39b..3818cb04e 100644 --- a/pkg/tcpip/transport/internal/network/BUILD +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -9,6 +9,7 @@ go_library( "endpoint_state.go", ], visibility = [ + "//pkg/tcpip/transport/icmp:__pkg__", "//pkg/tcpip/transport/raw:__pkg__", "//pkg/tcpip/transport/udp:__pkg__", ], diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index e4a64e191..689427d53 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -67,7 +67,7 @@ type endpoint struct { waiterQueue *waiter.Queue cooked bool ops tcpip.SocketOptions - stats tcpip.TransportEndpointStats `state:"nosave"` + stats tcpip.TransportEndpointStats // The following fields are used to manage the receive queue. rcvMu sync.Mutex `state:"nosave"` diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 3040a445b..bfef75da7 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -70,7 +70,7 @@ type endpoint struct { associated bool net network.Endpoint - stats tcpip.TransportEndpointStats `state:"nosave"` + stats tcpip.TransportEndpointStats ops tcpip.SocketOptions // The following fields are used to manage the receive queue and are diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index ff0a5df9c..7115d0a12 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -193,14 +193,6 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } -func (l *listenContext) useSynCookies() bool { - var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies - if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { - panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) - } - return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) -} - // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { @@ -277,7 +269,7 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu // Propagate any inheritable options from the listening endpoint // to the newly created endpoint. - l.listenEP.propagateInheritableOptionsLocked(ep) + l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce if !ep.reserveTupleLocked() { ep.mu.Unlock() @@ -367,7 +359,6 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } -// Precondition: h.ep.mu must be held. // +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep @@ -384,7 +375,7 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) { // cleanupCompletedHandshake transfers any state from the completed handshake to // the new endpoint. // -// Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e := h.ep if l.listenEP != nil { @@ -404,7 +395,8 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. // -// Precondition: e.mu and n.mu must be held. +// +checklocks:e.mu +// +checklocks:n.mu func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout n.portFlags = e.portFlags @@ -415,9 +407,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // reserveTupleLocked reserves an accepted endpoint's tuple. // -// Preconditions: -// * propagateInheritableOptionsLocked has been called. -// * e.mu is held. +// Precondition: e.propagateInheritableOptionsLocked has been called. +// +// +checklocks:e.mu func (e *endpoint) reserveTupleLocked() bool { dest := tcpip.FullAddress{ Addr: e.TransportEndpointInfo.ID.RemoteAddress, @@ -459,7 +451,7 @@ func (e *endpoint) notifyAborted() { // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error { defer s.decRef() @@ -552,7 +544,7 @@ func (a *accepted) acceptQueueIsFullLocked() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed @@ -579,8 +571,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } + alwaysUseSynCookies := func() bool { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + return bool(alwaysUseSynCookies) + }() + opts := parseSynSegmentOptions(s) - if !ctx.useSynCookies() { + if !alwaysUseSynCookies && !e.synRcvdBacklogFull() { s.incRef() atomic.AddInt32(&e.synRcvdCount, 1) return e.handleSynSegment(ctx, s, opts) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index b355fa7eb..049957b81 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -60,9 +60,8 @@ type endpoint struct { waiterQueue *waiter.Queue uniqueID uint64 net network.Endpoint - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats tcpip.TransportEndpointStats `state:"nosave"` - ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. |