diff options
Diffstat (limited to 'pkg')
48 files changed, 1523 insertions, 591 deletions
diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go index 0f8788f3b..170c15705 100644 --- a/pkg/lisafs/client_file.go +++ b/pkg/lisafs/client_file.go @@ -15,6 +15,8 @@ package lisafs import ( + "fmt" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -121,41 +123,92 @@ func (f *ClientFD) Sync(ctx context.Context) error { return err } +// chunkify applies fn to buf in chunks based on chunkSize. +func chunkify(chunkSize uint64, buf []byte, fn func([]byte, uint64) (uint64, error)) (uint64, error) { + toProcess := uint64(len(buf)) + var ( + totalProcessed uint64 + curProcessed uint64 + off uint64 + err error + ) + for { + if totalProcessed == toProcess { + return totalProcessed, nil + } + + if totalProcessed+chunkSize > toProcess { + curProcessed, err = fn(buf[totalProcessed:], off) + } else { + curProcessed, err = fn(buf[totalProcessed:totalProcessed+chunkSize], off) + } + totalProcessed += curProcessed + off += curProcessed + + if err != nil { + return totalProcessed, err + } + + // Return partial result immediately. + if curProcessed < chunkSize { + return totalProcessed, nil + } + + // If we received more bytes than we ever requested, this is a problem. + if totalProcessed > toProcess { + panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", totalProcessed, toProcess)) + } + } +} + // Read makes the PRead RPC. func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) { - req := PReadReq{ - Offset: offset, - FD: f.fd, - Count: uint32(len(dst)), - } + var resp PReadResp + // maxDataReadSize represents the maximum amount of data we can read at once + // (maximum message size - metadata size present in resp). Uninitialized + // resp.SizeBytes() correctly returns the metadata size only (since the read + // buffer is empty). + maxDataReadSize := uint64(f.client.maxMessageSize) - uint64(resp.SizeBytes()) + return chunkify(maxDataReadSize, dst, func(buf []byte, curOff uint64) (uint64, error) { + req := PReadReq{ + Offset: offset + curOff, + FD: f.fd, + Count: uint32(len(buf)), + } - resp := PReadResp{ // This will be unmarshalled into. Already set Buf so that we don't need to // allocate a temporary buffer during unmarshalling. // PReadResp.UnmarshalBytes expects this to be set. - Buf: dst, - } - - ctx.UninterruptibleSleepStart(false) - err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) - ctx.UninterruptibleSleepFinish(false) - return uint64(resp.NumBytes), err + resp.Buf = buf + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil) + ctx.UninterruptibleSleepFinish(false) + return uint64(resp.NumBytes), err + }) } // Write makes the PWrite RPC. func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) { - req := PWriteReq{ - Offset: primitive.Uint64(offset), - FD: f.fd, - NumBytes: primitive.Uint32(len(src)), - Buf: src, - } + var req PWriteReq + // maxDataWriteSize represents the maximum amount of data we can write at + // once (maximum message size - metadata size present in req). Uninitialized + // req.SizeBytes() correctly returns the metadata size only (since the write + // buffer is empty). + maxDataWriteSize := uint64(f.client.maxMessageSize) - uint64(req.SizeBytes()) + return chunkify(maxDataWriteSize, src, func(buf []byte, curOff uint64) (uint64, error) { + req = PWriteReq{ + Offset: primitive.Uint64(offset + curOff), + FD: f.fd, + NumBytes: primitive.Uint32(len(buf)), + Buf: buf, + } - var resp PWriteResp - ctx.UninterruptibleSleepStart(false) - err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) - ctx.UninterruptibleSleepFinish(false) - return resp.Count, err + var resp PWriteResp + ctx.UninterruptibleSleepStart(false) + err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil) + ctx.UninterruptibleSleepFinish(false) + return resp.Count, err + }) } // MkdirAt makes the MkdirAt RPC. diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go index 476ff76a5..5fc7c364d 100644 --- a/pkg/lisafs/testsuite/testsuite.go +++ b/pkg/lisafs/testsuite/testsuite.go @@ -330,8 +330,8 @@ func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root li defer closeFD(ctx, t, fd) defer unix.Close(hostFD) - // Test Read/Write RPCs. - data := make([]byte, 100) + // Test Read/Write RPCs with 2MB of data to test IO in chunks. + data := make([]byte, 1<<21) rand.Read(data) if err := writeFD(ctx, t, fd, 0, data); err != nil { t.Fatalf("write failed: %v", err) diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 9dac53c80..3f17fba49 100644 --- a/pkg/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -322,12 +322,3 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc func (p *PageTables) MarkReadOnlyShared() { p.readOnlyShared = true } - -// PrefaultRootTable touches the root table page to be sure that its physical -// pages are mapped. -// -//go:nosplit -//go:noinline -func (p *PageTables) PrefaultRootTable() PTE { - return p.root[0] -} diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index f322d2747..7fcb2d26b 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -84,6 +84,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + k := kernel.KernelFromContext(ctx) + fsDirChildren := make(map[string]kernfs.Inode) + // Create an empty directory to serve as the mount point for cgroupfs when + // cgroups are available. This emulates Linux behaviour, see + // kernel/cgroup.c:cgroup_init(). Note that in Linux, userspace (typically + // the init process) is ultimately responsible for actually mounting + // cgroupfs, but the kernel creates the mountpoint. For the sentry, the + // launcher mounts cgroupfs. + if k.CgroupRegistry() != nil { + fsDirChildren["cgroup"] = fs.newDir(ctx, creds, defaultSysDirMode, nil) + } + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), @@ -97,7 +109,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt }), }), "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), - "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, fsDirChildren), "kernel": kernelDir(ctx, fs, creds), "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 0a0d914cc..0c46a3a13 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -87,3 +87,17 @@ func TestSysRootContainsExpectedEntries(t *testing.T) { "power": linux.DT_DIR, }) } + +func TestCgroupMountpointExists(t *testing.T) { + // Note: The mountpoint is only created if cgroups are available. This is + // the VFS2 implementation of sysfs and the test runs with VFS2 enabled, so + // we expect to see the mount point unconditionally. + s := newTestSystem(t) + defer s.Destroy() + pop := s.PathOpAtRoot("/fs") + s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ + "cgroup": linux.DT_DIR, + }) + pop = s.PathOpAtRoot("/fs/cgroup") + s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ /*empty*/ }) +} diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 8a490b3de..a26f54269 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,13 +1,26 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "atomicptr_machine", + out = "atomicptr_machine_unsafe.go", + package = "kvm", + prefix = "machine", + template = "//pkg/sync/atomicptr:generic_atomicptr", + types = { + "Value": "machine", + }, +) + go_library( name = "kvm", srcs = [ "address_space.go", "address_space_amd64.go", "address_space_arm64.go", + "atomicptr_machine_unsafe.go", "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", @@ -69,10 +82,17 @@ go_test( "kvm_amd64_test.go", "kvm_amd64_test.s", "kvm_arm64_test.go", + "kvm_safecopy_test.go", "kvm_test.go", "virtual_map_test.go", ], library = ":kvm", + # FIXME(gvisor.dev/issue/3374): Not working with all build systems. + nogo = False, + # cgo has to be disabled. We have seen libc that blocks all signals and + # calls mmap from pthread_create, but we use SIGSYS to trap mmap system + # calls. + pure = True, tags = [ "manual", "nogotsan", @@ -81,8 +101,10 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/hostarch", + "//pkg/memutil", "//pkg/ring0", "//pkg/ring0/pagetables", + "//pkg/safecopy", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", "//pkg/sentry/platform", diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index bb9967b9f..826997e77 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -61,6 +61,9 @@ var ( // This is called by bluepillHandler. savedHandler uintptr + // savedSigsysHandler is a pointer to the previos handler of the SIGSYS signals. + savedSigsysHandler uintptr + // dieTrampolineAddr is the address of dieTrampoline. dieTrampolineAddr uintptr ) diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index c2a1dca11..5d8358f64 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -32,6 +32,8 @@ // This is checked as the source of the fault. #define CLI $0xfa +#define SYS_MMAP 9 + // See bluepill.go. TEXT ·bluepill(SB),NOSPLIT,$0 begin: @@ -95,6 +97,31 @@ TEXT ·addrOfSighandler(SB), $0-8 MOVQ AX, ret+0(FP) RET +TEXT ·sigsysHandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel. + MOVQ $1, CX + CMPL CX, 0x8(SI) + JNE fallback + + MOVL CONTEXT_RAX(DX), CX + CMPL CX, $SYS_MMAP + JNE fallback + PUSHQ DX // First argument (context). + CALL ·seccompMmapHandler(SB) // Call the handler. + POPQ DX // Discard the argument. + RET +fallback: + // Jump to the previous signal handler. + XORQ CX, CX + MOVQ ·savedSigsysHandler(SB), AX + JMP AX + +// func addrOfSighandler() uintptr +TEXT ·addrOfSigsysHandler(SB), $0-8 + MOVQ $·sigsysHandler(SB), AX + MOVQ AX, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 PUSHQ BX // First argument (vCPU). diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 308f2a951..9690e3772 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -29,9 +29,12 @@ // Only limited use of the context is done in the assembly stub below, most is // done in the Go handlers. #define SIGINFO_SIGNO 0x0 +#define SIGINFO_CODE 0x8 #define CONTEXT_PC 0x1B8 #define CONTEXT_R0 0xB8 +#define SYS_MMAP 222 + // getTLS returns the value of TPIDR_EL0 register. TEXT ·getTLS(SB),NOSPLIT,$0-8 MRS TPIDR_EL0, R1 @@ -98,6 +101,37 @@ TEXT ·addrOfSighandler(SB), $0-8 MOVD R0, ret+0(FP) RET +// The arguments are the following: +// +// R0 - The signal number. +// R1 - Pointer to siginfo_t structure. +// R2 - Pointer to ucontext structure. +// +TEXT ·sigsysHandler(SB),NOSPLIT,$0 + // si_code should be SYS_SECCOMP. + MOVD SIGINFO_CODE(R1), R7 + CMPW $1, R7 + BNE fallback + + CMPW $SYS_MMAP, R8 + BNE fallback + + MOVD R2, 8(RSP) + BL ·seccompMmapHandler(SB) // Call the handler. + + RET + +fallback: + // Jump to the previous signal handler. + MOVD ·savedHandler(SB), R7 + B (R7) + +// func addrOfSighandler() uintptr +TEXT ·addrOfSigsysHandler(SB), $0-8 + MOVD $·sigsysHandler(SB), R0 + MOVD R0, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 // R0: Fake the old PC as caller diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 0f0c1e73b..e38ca05c0 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -193,36 +193,8 @@ func bluepillHandler(context unsafe.Pointer) { return } - // Increment the fault count. - atomic.AddUint32(&c.faults, 1) - - // For MMIO, the physical address is the first data item. - physical = uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) - if !ok { - c.die(bluepillArchContext(context), "invalid physical address") - return - } - - // We now need to fill in the data appropriately. KVM - // expects us to provide the result of the given MMIO - // operation in the runData struct. This is safe - // because, if a fault occurs here, the same fault - // would have occurred in guest mode. The kernel should - // not create invalid page table mappings. - data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1])) - length := (uintptr)((uint32)(c.runData.data[2])) - write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0 - for i := uintptr(0); i < length; i++ { - b := bytePtr(uintptr(virtual) + i) - if write { - // Write to the given address. - *b = data[i] - } else { - // Read from the given address. - data[i] = *b - } - } + c.die(bluepillArchContext(context), "exit_mmio") + return case _KVM_EXIT_IRQ_WINDOW_OPEN: bluepillStopGuest(c) case _KVM_EXIT_SHUTDOWN: diff --git a/pkg/sentry/platform/kvm/kvm_safecopy_test.go b/pkg/sentry/platform/kvm/kvm_safecopy_test.go new file mode 100644 index 000000000..9a87c9e6f --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_safecopy_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// FIXME(gvisor.dev/issue//6629): These tests don't pass on ARM64. +// +//go:build amd64 +// +build amd64 + +package kvm + +import ( + "fmt" + "os" + "testing" + "unsafe" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/safecopy" +) + +func testSafecopy(t *testing.T, mapSize uintptr, fileSize uintptr, testFunc func(t *testing.T, c *vCPU, addr uintptr)) { + memfd, err := memutil.CreateMemFD(fmt.Sprintf("kvm_test_%d", os.Getpid()), 0) + if err != nil { + t.Errorf("error creating memfd: %v", err) + } + + memfile := os.NewFile(uintptr(memfd), "kvm_test") + memfile.Truncate(int64(fileSize)) + kvmTest(t, nil, func(c *vCPU) bool { + const n = 10 + mappings := make([]uintptr, n) + defer func() { + for i := 0; i < n && mappings[i] != 0; i++ { + unix.RawSyscall( + unix.SYS_MUNMAP, + mappings[i], mapSize, 0) + } + }() + for i := 0; i < n; i++ { + addr, _, errno := unix.RawSyscall6( + unix.SYS_MMAP, + 0, + mapSize, + unix.PROT_READ|unix.PROT_WRITE, + unix.MAP_SHARED|unix.MAP_FILE, + uintptr(memfile.Fd()), + 0) + if errno != 0 { + t.Errorf("error mapping file: %v", errno) + } + mappings[i] = addr + testFunc(t, c, addr) + } + return false + }) +} + +func TestSafecopySigbus(t *testing.T) { + mapSize := uintptr(faultBlockSize) + fileSize := mapSize - hostarch.PageSize + buf := make([]byte, hostarch.PageSize) + testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) { + want := safecopy.BusError{addr + fileSize} + bluepill(c) + _, err := safecopy.CopyIn(buf, unsafe.Pointer(addr+fileSize)) + if err != want { + t.Errorf("expected error: got %v, want %v", err, want) + } + }) +} + +func TestSafecopy(t *testing.T) { + mapSize := uintptr(faultBlockSize) + fileSize := mapSize + testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) { + want := uint32(0x12345678) + bluepill(c) + _, err := safecopy.SwapUint32(unsafe.Pointer(addr+fileSize-8), want) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + bluepill(c) + val, err := safecopy.LoadUint32(unsafe.Pointer(addr + fileSize - 8)) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if val != want { + t.Errorf("incorrect value: got %x, want %x", val, want) + } + }) +} diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index d67563958..dcf34015d 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -17,15 +17,19 @@ package kvm import ( "fmt" "runtime" + gosync "sync" "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/seccomp" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" ) @@ -35,6 +39,9 @@ type machine struct { // fd is the vm fd. fd int + // machinePoolIndex is the index in the machinePool array. + machinePoolIndex uint32 + // nextSlot is the next slot for setMemoryRegion. // // This must be accessed atomically. If nextSlot is ^uint32(0), then @@ -192,6 +199,10 @@ func (m *machine) newVCPU() *vCPU { return c // Done. } +// readOnlyGuestRegions contains regions that have to be mapped read-only into +// the guest physical address space. Right now, it is used on arm64 only. +var readOnlyGuestRegions []region + // newMachine returns a new VM context. func newMachine(vm int) (*machine, error) { // Create the machine. @@ -227,6 +238,10 @@ func newMachine(vm int) (*machine, error) { m.upperSharedPageTables.MarkReadOnlyShared() m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress) + // Install seccomp rules to trap runtime mmap system calls. They will + // be handled by seccompMmapHandler. + seccompMmapRules(m) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -241,32 +256,11 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) - var physicalRegionsReadOnly []physicalRegion - var physicalRegionsAvailable []physicalRegion - - physicalRegionsReadOnly = rdonlyRegionsForSetMem() - physicalRegionsAvailable = availableRegionsForSetMem() - - // Map all read-only regions. - for _, r := range physicalRegionsReadOnly { - m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) - } - // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to // ensure successful vCPU entry. - applyVirtualRegions(func(vr virtualRegion) { - if excludeVirtualRegion(vr) { - return // skip region. - } - - for _, r := range physicalRegionsReadOnly { - if vr.virtual == r.virtual { - return - } - } - + mapRegion := func(vr region, flags uint32) { for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -280,9 +274,32 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) + m.mapPhysical(physical, length, physicalRegions, flags) virtual += length } + } + + for _, vr := range readOnlyGuestRegions { + mapRegion(vr, _KVM_MEM_READONLY) + } + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + for _, r := range readOnlyGuestRegions { + if vr.virtual == r.virtual { + return + } + } + // Take into account that the stack can grow down. + if vr.filename == "[stack]" { + vr.virtual -= 1 << 20 + vr.length += 1 << 20 + } + + mapRegion(vr.region, 0) + }) // Initialize architecture state. @@ -352,6 +369,10 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg func (m *machine) Destroy() { runtime.SetFinalizer(m, nil) + machinePoolMu.Lock() + machinePool[m.machinePoolIndex].Store(nil) + machinePoolMu.Unlock() + // Destroy vCPUs. for _, c := range m.vCPUsByID { if c == nil { @@ -683,3 +704,72 @@ func (c *vCPU) setSystemTimeLegacy() error { } } } + +const machinePoolSize = 16 + +// machinePool is enumerated from the seccompMmapHandler signal handler +var ( + machinePool [machinePoolSize]machineAtomicPtr + machinePoolLen uint32 + machinePoolMu sync.Mutex + seccompMmapRulesOnce gosync.Once +) + +func sigsysHandler() +func addrOfSigsysHandler() uintptr + +// seccompMmapRules adds seccomp rules to trap mmap system calls that will be +// handled in seccompMmapHandler. +func seccompMmapRules(m *machine) { + seccompMmapRulesOnce.Do(func() { + // Install the handler. + if err := safecopy.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) + } + rules := []seccomp.RuleSet{} + rules = append(rules, []seccomp.RuleSet{ + // Trap mmap system calls and handle them in sigsysGoHandler + { + Rules: seccomp.SyscallRules{ + unix.SYS_MMAP: { + { + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + /* MAP_DENYWRITE is ignored and used only for filtering. */ + seccomp.MaskedEqual(unix.MAP_DENYWRITE, 0), + }, + }, + }, + Action: linux.SECCOMP_RET_TRAP, + }, + }...) + instrs, err := seccomp.BuildProgram(rules, linux.SECCOMP_RET_ALLOW, linux.SECCOMP_RET_ALLOW) + if err != nil { + panic(fmt.Sprintf("failed to build rules: %v", err)) + } + // Perform the actual installation. + if err := seccomp.SetFilter(instrs); err != nil { + panic(fmt.Sprintf("failed to set filter: %v", err)) + } + }) + + machinePoolMu.Lock() + n := atomic.LoadUint32(&machinePoolLen) + i := uint32(0) + for ; i < n; i++ { + if machinePool[i].Load() == nil { + break + } + } + if i == n { + if i == machinePoolSize { + machinePoolMu.Unlock() + panic("machinePool is full") + } + atomic.AddUint32(&machinePoolLen, 1) + } + machinePool[i].Store(m) + m.machinePoolIndex = i + machinePoolMu.Unlock() +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index a96634381..ab1e036b7 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -309,22 +309,6 @@ func loadByte(ptr *byte) byte { return *ptr } -// prefaultFloatingPointState touches each page of the floating point state to -// be sure that its physical pages are mapped. -// -// Otherwise the kernel can trigger KVM_EXIT_MMIO and an instruction that -// triggered a fault will be emulated by the kvm kernel code, but it can't -// emulate instructions like xsave and xrstor. -// -//go:nosplit -func prefaultFloatingPointState(data *fpu.State) { - size := len(*data) - for i := 0; i < size; i += hostarch.PageSize { - loadByte(&(*data)[i]) - } - loadByte(&(*data)[size-1]) -} - // SwitchToUser unpacks architectural-details. func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. @@ -355,11 +339,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) // allocations occur. entersyscall() bluepill(c) - // The root table physical page has to be mapped to not fault in iret - // or sysret after switching into a user address space. sysret and - // iret are in the upper half that is global and already mapped. - switchOpts.PageTables.PrefaultRootTable() - prefaultFloatingPointState(switchOpts.FloatingPointState) vector = c.CPU.SwitchToUser(switchOpts) exitsyscall() @@ -522,3 +501,7 @@ func (m *machine) getNewVCPU() *vCPU { } return nil } + +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index de798bb2c..fbacea9ad 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -161,3 +161,15 @@ func (c *vCPU) getSystemRegisters(sregs *systemRegs) unix.Errno { } return 0 } + +//go:nosplit +func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) { + ctx := bluepillArchContext(context) + + // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters. + addr, _, e := unix.RawSyscall6(uintptr(ctx.Rax), uintptr(ctx.Rdi), uintptr(ctx.Rsi), + uintptr(ctx.Rdx), uintptr(ctx.R10)|unix.MAP_DENYWRITE, uintptr(ctx.R8), uintptr(ctx.R9)) + ctx.Rax = uint64(addr) + + return addr, uintptr(ctx.Rsi), e +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 7937a8481..08d98c479 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -110,18 +110,128 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { return phyRegions } +// archPhysicalRegions fills readOnlyGuestRegions and allocates separate +// physical regions form them. +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return // skip region. + } + if !vr.accessType.Write { + readOnlyGuestRegions = append(readOnlyGuestRegions, vr.region) + } + }) + + rdRegions := readOnlyGuestRegions[:] + + // Add an unreachable region. + rdRegions = append(rdRegions, region{ + virtual: 0xffffffffffffffff, + length: 0, + }) + + var regions []physicalRegion + addValidRegion := func(r *physicalRegion, virtual, length uintptr) { + if length == 0 { + return + } + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), + }) + } + i := 0 + for _, pr := range physicalRegions { + start := pr.virtual + end := pr.virtual + pr.length + for start < end { + rdRegion := rdRegions[i] + rdStart := rdRegion.virtual + rdEnd := rdRegion.virtual + rdRegion.length + if rdEnd <= start { + i++ + continue + } + if rdStart > start { + newEnd := rdStart + if end < rdStart { + newEnd = end + } + addValidRegion(&pr, start, newEnd-start) + start = rdStart + continue + } + if rdEnd < end { + addValidRegion(&pr, start, rdEnd-start) + start = rdEnd + continue + } + addValidRegion(&pr, start, end-start) + start = end + } + } + + return regions +} + // Get all available physicalRegions. -func availableRegionsForSetMem() (phyRegions []physicalRegion) { - var excludeRegions []region +func availableRegionsForSetMem() []physicalRegion { + var excludedRegions []region applyVirtualRegions(func(vr virtualRegion) { if !vr.accessType.Write { - excludeRegions = append(excludeRegions, vr.region) + excludedRegions = append(excludedRegions, vr.region) } }) - phyRegions = computePhysicalRegions(excludeRegions) + // Add an unreachable region. + excludedRegions = append(excludedRegions, region{ + virtual: 0xffffffffffffffff, + length: 0, + }) - return phyRegions + var regions []physicalRegion + addValidRegion := func(r *physicalRegion, virtual, length uintptr) { + if length == 0 { + return + } + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), + }) + } + i := 0 + for _, pr := range physicalRegions { + start := pr.virtual + end := pr.virtual + pr.length + for start < end { + er := excludedRegions[i] + excludeEnd := er.virtual + er.length + excludeStart := er.virtual + if excludeEnd < start { + i++ + continue + } + if excludeStart < start { + start = excludeEnd + i++ + continue + } + rend := excludeStart + if rend > end { + rend = end + } + addValidRegion(&pr, start, rend-start) + start = excludeEnd + } + } + + return regions } // nonCanonical generates a canonical address return. diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 1a4a9ce7d..7e8e19dcb 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -333,3 +333,15 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) } } + +//go:nosplit +func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) { + ctx := bluepillArchContext(context) + + // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters. + addr, _, e := unix.RawSyscall6(uintptr(ctx.Regs[8]), uintptr(ctx.Regs[0]), uintptr(ctx.Regs[1]), + uintptr(ctx.Regs[2]), uintptr(ctx.Regs[3])|unix.MAP_DENYWRITE, uintptr(ctx.Regs[4]), uintptr(ctx.Regs[5])) + ctx.Regs[0] = uint64(addr) + + return addr, uintptr(ctx.Regs[1]), e +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index cc3a1253b..cf3a4e7c9 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -171,3 +171,46 @@ func (c *vCPU) setSignalMask() error { return nil } + +// seccompMmapHandler is a signal handler for runtime mmap system calls +// that are trapped by seccomp. +// +// It executes the mmap syscall with specified arguments and maps a new region +// to the guest. +// +//go:nosplit +func seccompMmapHandler(context unsafe.Pointer) { + addr, length, errno := seccompMmapSyscall(context) + if errno != 0 { + return + } + + for i := uint32(0); i < atomic.LoadUint32(&machinePoolLen); i++ { + m := machinePool[i].Load() + if m == nil { + continue + } + + // Map the new region to the guest. + vr := region{ + virtual: addr, + length: length, + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { + physical, length, ok := translateToPhysical(virtual) + if !ok { + // This must be an invalid region that was + // knocked out by creation of the physical map. + return + } + if virtual+length > vr.virtual+vr.length { + // Cap the length to the end of the area. + length = vr.virtual + vr.length - virtual + } + + // Ensure the physical range is mapped. + m.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) + virtual += length + } + } +} diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index d812e6c26..9864d1258 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -168,6 +168,9 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica } addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd) + // Do arch-specific actions on physical regions. + physicalRegions = archPhysicalRegions(physicalRegions) + // Dump our all physical regions. for _, r := range physicalRegions { log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)", diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go index 6d0ba8252..346a10043 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -30,8 +30,8 @@ import ( func TLSWorks() bool // SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *arch.Registers, fn func()) { - regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) +func SetTestTarget(regs *arch.Registers, fn uintptr) { + regs.Pc = uint64(fn) } // SetTouchTarget sets rax appropriately. diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 7348c29a5..42876245a 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -28,6 +28,11 @@ TEXT ·Getpid(SB),NOSPLIT,$0 SVC RET +TEXT ·AddrOfGetpid(SB),NOSPLIT,$0-8 + MOVD $·Getpid(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·Touch(SB),NOSPLIT,$0 start: MOVD 0(R8), R1 @@ -35,21 +40,41 @@ start: SVC B start +TEXT ·AddrOfTouch(SB),NOSPLIT,$0-8 + MOVD $·Touch(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·HaltLoop(SB),NOSPLIT,$0 start: HLT B start +TEXT ·AddOfHaltLoop(SB),NOSPLIT,$0-8 + MOVD $·HaltLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + // This function simulates a loop of syscall. TEXT ·SyscallLoop(SB),NOSPLIT,$0 start: SVC B start +TEXT ·AddrOfSyscallLoop(SB),NOSPLIT,$0-8 + MOVD $·SyscallLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·SpinLoop(SB),NOSPLIT,$0 start: B start +TEXT ·AddrOfSpinLoop(SB),NOSPLIT,$0-8 + MOVD $·SpinLoop(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·TLSWorks(SB),NOSPLIT,$0-8 NO_LOCAL_POINTERS MOVD $0x6789, R5 @@ -125,6 +150,11 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 SVC RET // never reached +TEXT ·AddrOfTwiddleRegsSyscall(SB),NOSPLIT,$0-8 + MOVD $·TwiddleRegsSyscall(SB), R0 + MOVD R0, ret+0(FP) + RET + TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 TWIDDLE_REGS() MSR R10, TPIDR_EL0 @@ -132,3 +162,8 @@ TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 // Branch to Register branches unconditionally to an address in <Rn>. JMP (R6) // <=> br x6, must fault RET // never reached + +TEXT ·AddrOfTwiddleRegsFault(SB),NOSPLIT,$0-8 + MOVD $·TwiddleRegsFault(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD index 943fa180d..35feb969f 100644 --- a/pkg/sentry/seccheck/BUILD +++ b/pkg/sentry/seccheck/BUILD @@ -8,6 +8,8 @@ go_fieldenum( name = "seccheck_fieldenum", srcs = [ "clone.go", + "execve.go", + "exit.go", "task.go", ], out = "seccheck_fieldenum.go", @@ -29,6 +31,8 @@ go_library( name = "seccheck", srcs = [ "clone.go", + "execve.go", + "exit.go", "seccheck.go", "seccheck_fieldenum.go", "seqatomic_checkerslice_unsafe.go", diff --git a/pkg/sentry/seccheck/execve.go b/pkg/sentry/seccheck/execve.go new file mode 100644 index 000000000..f36e0730e --- /dev/null +++ b/pkg/sentry/seccheck/execve.go @@ -0,0 +1,65 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// ExecveInfo contains information used by the Execve checkpoint. +// +// +fieldenum Execve +type ExecveInfo struct { + // Invoker identifies the invoking thread. + Invoker TaskInfo + + // Credentials are the invoking thread's credentials. + Credentials *auth.Credentials + + // BinaryPath is a path to the executable binary file being switched to in + // the mount namespace in which it was opened. + BinaryPath string + + // Argv is the new process image's argument vector. + Argv []string + + // Env is the new process image's environment variables. + Env []string + + // BinaryMode is the executable binary file's mode. + BinaryMode uint16 + + // BinarySHA256 is the SHA-256 hash of the executable binary file. + // + // Note that this requires reading the entire file into memory, which is + // likely to be extremely slow. + BinarySHA256 [32]byte +} + +// ExecveReq returns fields required by the Execve checkpoint. +func (s *state) ExecveReq() ExecveFieldSet { + return s.execveReq.Load() +} + +// Execve is called at the Execve checkpoint. +func (s *state) Execve(ctx context.Context, mask ExecveFieldSet, info *ExecveInfo) error { + for _, c := range s.getCheckers() { + if err := c.Execve(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/exit.go b/pkg/sentry/seccheck/exit.go new file mode 100644 index 000000000..69cb6911c --- /dev/null +++ b/pkg/sentry/seccheck/exit.go @@ -0,0 +1,57 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" +) + +// ExitNotifyParentInfo contains information used by the ExitNotifyParent +// checkpoint. +// +// +fieldenum ExitNotifyParent +type ExitNotifyParentInfo struct { + // Exiter identifies the exiting thread. Note that by the checkpoint's + // definition, Exiter.ThreadID == Exiter.ThreadGroupID and + // Exiter.ThreadStartTime == Exiter.ThreadGroupStartTime, so requesting + // ThreadGroup* fields is redundant. + Exiter TaskInfo + + // ExitStatus is the exiting thread group's exit status, as reported + // by wait*(). + ExitStatus linux.WaitStatus +} + +// ExitNotifyParentReq returns fields required by the ExitNotifyParent +// checkpoint. +func (s *state) ExitNotifyParentReq() ExitNotifyParentFieldSet { + return s.exitNotifyParentReq.Load() +} + +// ExitNotifyParent is called at the ExitNotifyParent checkpoint. +// +// The ExitNotifyParent checkpoint occurs when a zombied thread group leader, +// not waiting for exit acknowledgement from a non-parent ptracer, becomes the +// last non-dead thread in its thread group and notifies its parent of its +// exiting. +func (s *state) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info *ExitNotifyParentInfo) error { + for _, c := range s.getCheckers() { + if err := c.ExitNotifyParent(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go index b6c9d44ce..e13274096 100644 --- a/pkg/sentry/seccheck/seccheck.go +++ b/pkg/sentry/seccheck/seccheck.go @@ -29,6 +29,8 @@ type Point uint // PointX represents the checkpoint X. const ( PointClone Point = iota + PointExecve + PointExitNotifyParent // Add new Points above this line. pointLength @@ -47,6 +49,8 @@ const ( // registered concurrently with invocations of checkpoints). type Checker interface { Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error + Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error + ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error } // CheckerDefaults may be embedded by implementations of Checker to obtain @@ -58,6 +62,16 @@ func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info Clone return nil } +// Execve implements Checker.Execve. +func (CheckerDefaults) Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error { + return nil +} + +// ExitNotifyParent implements Checker.ExitNotifyParent. +func (CheckerDefaults) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error { + return nil +} + // CheckerReq indicates what checkpoints a corresponding Checker runs at, and // what information it requires at those checkpoints. type CheckerReq struct { @@ -69,7 +83,9 @@ type CheckerReq struct { // All of the following fields indicate what fields in the corresponding // XInfo struct will be requested at the corresponding checkpoint. - Clone CloneFields + Clone CloneFields + Execve ExecveFields + ExitNotifyParent ExitNotifyParentFields } // Global is the method receiver of all seccheck functions. @@ -101,7 +117,9 @@ type state struct { // corresponding XInfo struct have been requested by any registered // checker, are accessed using atomic memory operations, and are mutated // with registrationMu locked. - cloneReq CloneFieldSet + cloneReq CloneFieldSet + execveReq ExecveFieldSet + exitNotifyParentReq ExitNotifyParentFieldSet } // AppendChecker registers the given Checker to execute at checkpoints. The @@ -110,7 +128,11 @@ type state struct { func (s *state) AppendChecker(c Checker, req *CheckerReq) { s.registrationMu.Lock() defer s.registrationMu.Unlock() + s.cloneReq.AddFieldsLoadable(req.Clone) + s.execveReq.AddFieldsLoadable(req.Execve) + s.exitNotifyParentReq.AddFieldsLoadable(req.ExitNotifyParent) + s.appendCheckerLocked(c) for _, p := range req.Points { word, bit := p/32, p%32 diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index ea56f39c1..0f6e576a9 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID { } // Action implements stack.Target.Action. -func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go index 82b6df18c..7b9c2a4db 100644 --- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go +++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go @@ -37,6 +37,8 @@ func (p *AtomicPtr) loadPtr(v *Value) { // Load returns the value set by the most recent Store. It returns nil if there // has been no previous call to Store. +// +//go:nosplit func (p *AtomicPtr) Load() *Value { return (*Value)(atomic.LoadPointer(&p.ptr)) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 25f5a52e3..dda473e48 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -426,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -466,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -549,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -576,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -717,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -744,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -841,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -940,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index dab99d00d..e2d2cf907 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1100,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -1183,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 068dab7ce..4fb7e9adb 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // // Precondition: cn.mu must be held. -func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) { + if pkt.TransportProtocolNumber != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. @@ -209,27 +215,38 @@ type bucket struct { tuples tupleList } +func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { + return tcpHeader, true + } + case header.UDPProtocolNumber: + if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { + return udpHeader, true + } + } + + return nil, false +} + // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, &tcpip.ErrUnknownProtocol{} - } - - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ srcAddr: netHeader.SourceAddress(), - srcPort: tcpHeader.SourcePort(), + srcPort: transportHeader.SourcePort(), dstAddr: netHeader.DestinationAddress(), - dstPort: tcpHeader.DestinationPort(), - transProto: netHeader.TransportProtocol(), + dstPort: transportHeader.DestinationPort(), + transProto: pkt.TransportProtocolNumber, netProto: pkt.NetworkProtocolNumber, }, nil } @@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/6168): Support UDP. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return false } @@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return false - } // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be // validated if checksum offloading is off. It may require IP defrag if the @@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { switch hook { case Prerouting, Output: - if conn.manip == manipDestination { - switch dir { - case dirOriginal: - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - case dirReply: - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - - updateSRCFields = true - } + if conn.manip == manipDestination && dir == dirOriginal { + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr + pkt.NatDone = true + } else if conn.manip == manipSource && dir == dirReply { + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource { - switch dir { - case dirOriginal: - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - - updateSRCFields = true - case dirReply: - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr - } + if conn.manip == manipSource && dir == dirOriginal { + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + updateSRCFields = true + pkt.NatDone = true + } else if conn.manip == manipDestination && dir == dirReply { + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + updateSRCFields = true pkt.NatDone = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } + if !pkt.NatDone { return false } @@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { fullChecksum := false updatePseudoHeader := false switch hook { - case Prerouting, Input: + case Prerouting: + // Packet came from outside the stack so it must have a checksum set + // already. + fullChecksum = true + updatePseudoHeader = true + case Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { fullChecksum = true @@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { rewritePacket( netHeader, - tcpHeader, + transportHeader, updateSRCFields, fullChecksum, updatePseudoHeader, @@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) return false } @@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } - // We only track TCP connections. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber, header.UDPProtocolNumber: + default: + // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable + // connections. return } @@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) ct.insertConn(conn) } @@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ srcPort: epID.LocalPort, dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, - transProto: header.TCPProtocolNumber, + transProto: transProto, netProto: netProto, } conn, _ := ct.connForTID(tid) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f152c0d83..74c9075b4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -264,12 +264,62 @@ const ( chainReturn ) -// Check runs pkt through the rules for hook. It returns true when the packet +// CheckPrerouting performs the prerouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool { + return it.check(Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) +} + +// CheckInput performs the input hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { + return it.check(Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) +} + +// CheckForward performs the forward hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool { + return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) +} + +// CheckOutput performs the output hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool { + return it.check(Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// CheckPostrouting performs the postrouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool { + return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// check runs pkt through the rules for hook. It returns true when the packet // should continue traversing the network stack and false when it should be // dropped. // -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -300,7 +350,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -311,7 +361,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v { case RuleAccept: continue case RuleDrop: @@ -375,19 +425,35 @@ func (it *IPTables) startReaper(interval time.Duration) { }() } -// CheckPackets runs pkts through the rules for hook and returns a map of packets that -// should not go forward. +// CheckOutputPackets performs the output hook on the packets. // -// Preconditions: -// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// * pkt.NetworkHeader is not nil. +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return it.checkPackets(Output, pkts, r, outNicName) +} + +// CheckPostroutingPackets performs the postrouting hook on the packets. +// +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return it.checkPackets(Postrouting, pkts, r, outNicName) +} + +// checkPackets runs pkts through the rules for hook and returns a map of +// packets that should not go forward. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { + if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -407,11 +473,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -428,7 +494,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -454,7 +520,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -477,16 +543,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) + return rule.Target.Action(pkt, &it.connections, hook, r, addressEP) } // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, &tcpip.ErrNotConnected{} } - return it.connections.originalDst(epID, netProto) + return it.connections.originalDst(epID, netProto, transProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 96cc899bb..e8806ebdb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,7 +79,7 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleReturn, 0 } @@ -97,7 +97,7 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -117,6 +117,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. + var address tcpip.Address switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { @@ -125,7 +126,8 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r address = header.IPv6Loopback } case Prerouting: - // No-op, as address is already set correctly. + // addressEP is expected to be set for the prerouting hook. + address = addressEP.MainAddress().Address default: panic("redirect target is supported only on output and prerouting hooks") } @@ -180,7 +182,7 @@ type SNATTarget struct { } // Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if st.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -206,34 +208,28 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - header.UDP(pkt.TransportHeader().View()), - true, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - st.Port, - st.Addr, - ) - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 + port := st.Port + + if port == 0 { + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + if port == 0 { + port = header.UDP(pkt.TransportHeader().View()).SourcePort() + } + case header.TCPProtocolNumber: + if port == 0 { + port = header.TCP(pkt.TransportHeader().View()).SourcePort() + } } + } - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) - } - default: - return RuleDrop, 0 + // Set up conection for matching NAT rule. Only the first packet of the + // connection comes here. Other packets will be manipulated in connection + // tracking. + // + // Does nothing if the protocol does not support connection tracking. + if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil { + ct.handlePacket(pkt, hook, r) } return RuleAccept, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 66e5f22ac..976194124 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -352,5 +352,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index b9280c2de..bf248ef20 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -335,9 +335,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // tell if a noop connection should be inserted at Input hook. Once conntrack // redefines the manipulation field as mutable, we won't need the special noop // connection. - if pk.NatDone { - newPk.NatDone = true - } + newPk.NatDone = pk.NatDone return newPk } @@ -347,7 +345,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // The returned packet buffer will have the network and transport headers // set if the original packet buffer did. func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer { - newPkt := NewPacketBuffer(PacketBufferOptions{ + newPk := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: reservedHeaderBytes, Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(), IsForwardedPacket: true, @@ -355,21 +353,28 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu { consumeBytes := pk.NetworkHeader().View().Size() - if _, consumed := newPkt.NetworkHeader().Consume(consumeBytes); !consumed { + if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed { panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes)) } - newPkt.NetworkProtocolNumber = pk.NetworkProtocolNumber + newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber } { consumeBytes := pk.TransportHeader().View().Size() - if _, consumed := newPkt.TransportHeader().Consume(consumeBytes); !consumed { + if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed { panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes)) } - newPkt.TransportProtocolNumber = pk.TransportProtocolNumber + newPk.TransportProtocolNumber = pk.TransportProtocolNumber } - return newPkt + // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to + // maintain this flag in the packet. Currently conntrack needs this flag to + // tell if a noop connection should be inserted at Input hook. Once conntrack + // redefines the manipulation field as mutable, we won't need the special noop + // connection. + newPk.NatDone = pk.NatDone + + return newPk } // headerInfo stores metadata about a header in a packet. diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c5e896295..d45a2c05c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1269,6 +1269,8 @@ type TransportProtocolNumber uint32 type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. +// +// +stateify savable type StatCounter struct { count atomicbitops.AlignedAtomicUint64 } @@ -1995,6 +1997,8 @@ type Stats struct { } // ReceiveErrors collects packet receive errors within transport endpoint. +// +// +stateify savable type ReceiveErrors struct { // ReceiveBufferOverflow is the number of received packets dropped // due to the receive buffer being full. @@ -2012,8 +2016,10 @@ type ReceiveErrors struct { ChecksumErrors StatCounter } -// SendErrors collects packet send errors within the transport layer for -// an endpoint. +// SendErrors collects packet send errors within the transport layer for an +// endpoint. +// +// +stateify savable type SendErrors struct { // SendToNetworkFailed is the number of packets failed to be written to // the network endpoint. @@ -2024,6 +2030,8 @@ type SendErrors struct { } // ReadErrors collects segment read errors from an endpoint read call. +// +// +stateify savable type ReadErrors struct { // ReadClosed is the number of received packet drops because the endpoint // was shutdown for read. @@ -2039,6 +2047,8 @@ type ReadErrors struct { } // WriteErrors collects packet write errors from an endpoint write call. +// +// +stateify savable type WriteErrors struct { // WriteClosed is the number of packet drops because the endpoint // was shutdown for write. @@ -2054,6 +2064,8 @@ type WriteErrors struct { } // TransportEndpointStats collects statistics about the endpoint. +// +// +stateify savable type TransportEndpointStats struct { // PacketsReceived is the number of successful packet receives. PacketsReceived StatCounter diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 181ef799e..7c998eaae 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -34,12 +34,16 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 28b49c6be..bdf4a64b9 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -15,19 +15,24 @@ package iptables_test import ( + "bytes" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) type inputIfNameMatcher struct { @@ -1156,3 +1161,286 @@ func TestInputHookWithLocalForwarding(t *testing.T) { }) } } + +func TestSNAT(t *testing.T) { + const listenPort = 8080 + + type endpointAndAddresses struct { + serverEP tcpip.Endpoint + serverAddr tcpip.Address + serverReadableCH chan struct{} + + clientEP tcpip.Endpoint + clientAddr tcpip.Address + clientReadableCH chan struct{} + + nattedClientAddr tcpip.Address + } + + newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + t.Cleanup(func() { + wq.EventUnregister(&we) + }) + + ep, err := s.NewEndpoint(transProto, netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) + } + t.Cleanup(ep.Close) + + return ep, ch + } + + tests := []struct { + name string + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses + }{ + { + name: "IPv4 host1 server with host2 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + ipt := routerStack.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} + filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err) + } + + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + + nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, + } + }, + }, + { + name: "IPv6 host1 server with host2 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + ipt := routerStack.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} + filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err) + } + + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + + nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, + } + }, + }, + } + + subTests := []struct { + name string + proto tcpip.TransportProtocolNumber + expectedConnectErr tcpip.Error + setupServer func(t *testing.T, ep tcpip.Endpoint) + setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + needRemoteAddr bool + }{ + { + name: "UDP", + proto: udp.ProtocolNumber, + expectedConnectErr: nil, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + if err := ep.Connect(clientAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) + } + return nil, nil + }, + needRemoteAddr: true, + }, + { + name: "TCP", + proto: tcp.ProtocolNumber, + expectedConnectErr: &tcpip.ErrConnectStarted{}, + setupServer: func(t *testing.T, ep tcpip.Endpoint) { + t.Helper() + + if err := ep.Listen(1); err != nil { + t.Fatalf("ep.Listen(1): %s", err) + } + }, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + var addr tcpip.FullAddress + for { + newEP, wq, err := ep.Accept(&addr) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Accept(_): %s", err) + } + if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( + "NIC", + )); diff != "" { + t.Errorf("accepted address mismatch (-want +got):\n%s", diff) + } + + we, newCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + return newEP, newCH + } + }, + needRemoteAddr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } + } + nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr} + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + nattedClientAddr.Port = addr.Port + } + + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch + } + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + var r bytes.Reader + r.Reset(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(&r, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } + } + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + var buf bytes.Buffer + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } + + { + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + read(serverCH, serverEP, data, nattedClientAddr) + } + + { + data := []byte{5, 6, 7, 8, 9, 10, 11, 12} + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 947bcc7b1..c69410859 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -40,6 +40,14 @@ const ( Host2NICID = 4 ) +// Common NIC names used by tests. +const ( + Host1NICName = "host1NIC" + RouterNIC1Name = "routerNIC1" + RouterNIC2Name = "routerNIC2" + Host2NICName = "host2NIC" +) + // Common link addresses used by tests. const ( LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") @@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2) routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4) - if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err) + { + opts := stack.NICOptions{Name: Host1NICName} + if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil { + t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err) + { + opts := stack.NICOptions{Name: RouterNIC1Name} + if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err) + { + opts := stack.NICOptions{Name: RouterNIC2Name} + if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err) + } } - if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) + { + opts := stack.NICOptions{Name: Host2NICName} + if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil { + t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err) + } } if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index bbc0e3ecc..4718ec4ec 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/waiter", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 1e519085d..bb0db9f70 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,7 @@ package icmp import ( + "fmt" "io" "time" @@ -24,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -35,15 +38,6 @@ type icmpPacket struct { receivedAt time.Time `state:".(int64)"` } -type endpointState int - -const ( - stateInitial endpointState = iota - stateBound - stateConnected - stateClosed -) - // endpoint represents an ICMP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -51,14 +45,17 @@ const ( // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -70,38 +67,23 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - state endpointState - route *stack.Route `state:"manual"` - ttl uint8 - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool + ident uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, - state: stateInitial, uniqueID: s.UniqueID(), } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) + ep.net.Init(s, netProto, transProto, &ep.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -128,35 +110,40 @@ func (e *endpoint) Abort() { // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { - e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { - case stateBound, stateConnected: - bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) - } + notify := func() bool { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + return false + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + info := e.net.Info() + info.ID.LocalPort = e.ident + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice())) + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } - // Close the receive list and drain it. - e.rcvMu.Lock() - e.rcvClosed = true - e.rcvBufSize = 0 - for !e.rcvList.Empty() { - p := e.rcvList.Front() - e.rcvList.Remove(p) - } - e.rcvMu.Unlock() + e.net.Shutdown() + e.net.Close() - if e.route != nil { - e.route.Release() - e.route = nil - } - - // Update the state. - e.state = stateClosed + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } - e.mu.Unlock() + return true + }() - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + if notify { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + } } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. @@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -214,13 +201,12 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // // Returns true for retry if preparation should be retried. // +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.state { - case stateInitial: - case stateConnected: +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - - case stateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { - return 0, err + return network.WriteContext{}, 0, err } if !retry { @@ -298,36 +272,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } } - route := e.route - if to != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := to.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - - dst, netProto, err := e.checkV4MappedLocked(*to) - if err != nil { - return 0, err - } - - // Find the endpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) - if err != nil { - return 0, err - } - defer r.Release() + ctx, err := e.net.AcquireContextForWrite(opts) + return ctx, e.ident, err +} - route = r +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + ctx, ident, err := e.prepareForWrite(opts) + if err != nil { + return 0, err } + defer ctx.Release() // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) @@ -335,17 +289,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return 0, &tcpip.ErrBadBuffer{} } - var err tcpip.Error - switch e.NetProto { + switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) + if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } case header.IPv6ProtocolNumber: - err = send6(route, e.ID.LocalPort, v, e.ttl) - } - - if err != nil { - return 0, err + if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } + default: + panic(fmt.Sprintf("unhandled network protocol = %d", netProto)) } return int64(len(v)), nil @@ -358,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool { return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt sets a socket option. -func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return nil +// SetSockOpt implements tcpip.Endpoint. +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + return e.net.SetSockOpt(opt) } -// SetSockOptInt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - } - return nil + return e.net.SetSockOptInt(opt, v) } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -388,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.rcvMu.Lock() - v := int(e.ttl) - e.rcvMu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +// GetSockOpt implements tcpip.Endpoint. +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { +func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength), }) - pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber @@ -427,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest - icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - pkt.Data().AppendView(data) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V4.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { +func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength), }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) @@ -469,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest pkt.Data().AppendView(data) dataRange := pkt.Data().AsRange() icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpv6, - Src: r.LocalAddress(), - Dst: r.RemoteAddress(), + Src: src, + Dst: dst, PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V6.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() + return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() tcpip.Error { return &tcpip.ErrNotSupported{} @@ -516,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - localPort := uint16(0) - switch e.state { - case stateInitial: - case stateBound, stateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.ident - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + nextID, err := e.registerWithStack(netProto, nextID) + if err != nil { + return err } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } - - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress(), - LocalPort: localPort, - RemoteAddress: r.RemoteAddress(), - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - id, err = e.registerWithStack(nicID, netProtos, id) + e.ident = nextID.LocalPort + return nil + }) if err != nil { - r.Release() return err } - e.ID = id - e.route = r - e.RegisterNICID = nicID - - e.state = stateConnected - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -586,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - e.shutdownFlags |= flags - if e.state != stateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } + + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } } if flags&tcpip.ShutdownRead != 0 { @@ -616,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) - return id, err + return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) + err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -645,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err := e.registerWithStack(boundNetProto, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, err = e.registerWithStack(addr.NIC, netProtos, id) + e.ident = id.LocalPort + return nil + }) if err != nil { return err } - e.ID = id - e.RegisterNICID = addr.NIC - - // Mark endpoint as bound. - e.state = stateBound - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -692,7 +571,7 @@ func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || - e.stack.IsSubnetBroadcast(nicID, e.NetProto, addr) + e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr) } // Bind binds the endpoint to a specific local address and port. @@ -705,15 +584,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - err := e.bindLocked(addr) - if err != nil { - return err - } - - e.BindNICID = addr.NIC - e.BindAddr = addr.Addr - - return nil + return e.bindLocked(addr) } // GetLocalAddress returns the address to which the endpoint is bound. @@ -721,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.ident + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -733,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { - return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} + if addr, connected := e.net.GetRemoteAddress(); connected { + return addr, nil } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness returns the current readiness of the endpoint. For example, if @@ -766,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. - switch e.NetProto { + switch e.net.NetProto() { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { @@ -840,9 +705,9 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + defer e.mu.RUnlock() + ret := e.net.Info() + ret.ID.LocalPort = e.ident return &ret } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index b8b839e4a..dfe453ff9 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,13 @@ package icmp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.thaw() + + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - if e.state != stateBound && e.state != stateConnected { - return - } - - var err tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + var err tcpip.Error + info := e.net.Info() + info.ID.LocalPort = e.ident + info.ID, err = e.registerWithStack(info.NetProto, info.ID) if err != nil { - panic(err) + panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err)) } - - e.ID.LocalAddress = e.route.LocalAddress() - } else if len(e.ID.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) - if err != nil { - panic(err) + e.ident = info.ID.LocalPort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD index b1edce39b..3818cb04e 100644 --- a/pkg/tcpip/transport/internal/network/BUILD +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -9,6 +9,7 @@ go_library( "endpoint_state.go", ], visibility = [ + "//pkg/tcpip/transport/icmp:__pkg__", "//pkg/tcpip/transport/raw:__pkg__", "//pkg/tcpip/transport/udp:__pkg__", ], diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index e4a64e191..689427d53 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -67,7 +67,7 @@ type endpoint struct { waiterQueue *waiter.Queue cooked bool ops tcpip.SocketOptions - stats tcpip.TransportEndpointStats `state:"nosave"` + stats tcpip.TransportEndpointStats // The following fields are used to manage the receive queue. rcvMu sync.Mutex `state:"nosave"` diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 3040a445b..bfef75da7 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -70,7 +70,7 @@ type endpoint struct { associated bool net network.Endpoint - stats tcpip.TransportEndpointStats `state:"nosave"` + stats tcpip.TransportEndpointStats ops tcpip.SocketOptions // The following fields are used to manage the receive queue and are diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 5148fe157..20958d882 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -80,9 +80,10 @@ go_library( go_test( name = "tcp_x_test", - size = "medium", + size = "large", srcs = [ "dual_stack_test.go", + "rcv_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", "tcp_rack_test.go", @@ -114,16 +115,6 @@ go_test( ) go_test( - name = "rcv_test", - size = "small", - srcs = ["rcv_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( name = "tcp_test", size = "small", srcs = [ diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index ff0a5df9c..7115d0a12 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -193,14 +193,6 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } -func (l *listenContext) useSynCookies() bool { - var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies - if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { - panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) - } - return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) -} - // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { @@ -277,7 +269,7 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu // Propagate any inheritable options from the listening endpoint // to the newly created endpoint. - l.listenEP.propagateInheritableOptionsLocked(ep) + l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce if !ep.reserveTupleLocked() { ep.mu.Unlock() @@ -367,7 +359,6 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } -// Precondition: h.ep.mu must be held. // +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep @@ -384,7 +375,7 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) { // cleanupCompletedHandshake transfers any state from the completed handshake to // the new endpoint. // -// Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e := h.ep if l.listenEP != nil { @@ -404,7 +395,8 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. // -// Precondition: e.mu and n.mu must be held. +// +checklocks:e.mu +// +checklocks:n.mu func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout n.portFlags = e.portFlags @@ -415,9 +407,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // reserveTupleLocked reserves an accepted endpoint's tuple. // -// Preconditions: -// * propagateInheritableOptionsLocked has been called. -// * e.mu is held. +// Precondition: e.propagateInheritableOptionsLocked has been called. +// +// +checklocks:e.mu func (e *endpoint) reserveTupleLocked() bool { dest := tcpip.FullAddress{ Addr: e.TransportEndpointInfo.ID.RemoteAddress, @@ -459,7 +451,7 @@ func (e *endpoint) notifyAborted() { // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error { defer s.decRef() @@ -552,7 +544,7 @@ func (a *accepted) acceptQueueIsFullLocked() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed @@ -579,8 +571,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } + alwaysUseSynCookies := func() bool { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + return bool(alwaysUseSynCookies) + }() + opts := parseSynSegmentOptions(s) - if !ctx.useSynCookies() { + if !alwaysUseSynCookies && !e.synRcvdBacklogFull() { s.incRef() atomic.AddInt32(&e.synRcvdCount, 1) return e.handleSynSegment(ctx, s, opts) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a3002abf3..407ab2664 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2066,7 +2066,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber) e.UnlockUser() if err != nil { return err diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index 8a026ec46..e47a07030 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rcv_test +package tcp_test import ( "testing" diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index b355fa7eb..049957b81 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -60,9 +60,8 @@ type endpoint struct { waiterQueue *waiter.Queue uniqueID uint64 net network.Endpoint - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats tcpip.TransportEndpointStats `state:"nosave"` - ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. |