summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/lisafs/client_file.go101
-rw-r--r--pkg/lisafs/testsuite/testsuite.go4
-rw-r--r--pkg/ring0/pagetables/pagetables.go9
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go14
-rw-r--r--pkg/sentry/fsimpl/sys/sys_test.go14
-rw-r--r--pkg/sentry/platform/kvm/BUILD22
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go3
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.s27
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.s34
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go32
-rw-r--r--pkg/sentry/platform/kvm/kvm_safecopy_test.go104
-rw-r--r--pkg/sentry/platform/kvm/machine.go136
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go25
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go12
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go120
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go12
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go43
-rw-r--r--pkg/sentry/platform/kvm/physical_map.go3
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.go4
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.s35
-rw-r--r--pkg/sentry/seccheck/BUILD4
-rw-r--r--pkg/sentry/seccheck/execve.go65
-rw-r--r--pkg/sentry/seccheck/exit.go57
-rw-r--r--pkg/sentry/seccheck/seccheck.go26
-rw-r--r--pkg/sentry/socket/netfilter/targets.go2
-rw-r--r--pkg/sync/atomicptr/generic_atomicptr_unsafe.go2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go16
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go16
-rw-r--r--pkg/tcpip/stack/conntrack.go110
-rw-r--r--pkg/tcpip/stack/iptables.go104
-rw-r--r--pkg/tcpip/stack/iptables_targets.go64
-rw-r--r--pkg/tcpip/stack/iptables_types.go2
-rw-r--r--pkg/tcpip/stack/packet_buffer.go23
-rw-r--r--pkg/tcpip/tcpip.go16
-rw-r--r--pkg/tcpip/tests/integration/BUILD4
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go288
-rw-r--r--pkg/tcpip/tests/utils/utils.go36
-rw-r--r--pkg/tcpip/transport/icmp/BUILD2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go425
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go35
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD1
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/BUILD13
-rw-r--r--pkg/tcpip/transport/tcp/accept.go36
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go5
48 files changed, 1523 insertions, 591 deletions
diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go
index 0f8788f3b..170c15705 100644
--- a/pkg/lisafs/client_file.go
+++ b/pkg/lisafs/client_file.go
@@ -15,6 +15,8 @@
package lisafs
import (
+ "fmt"
+
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -121,41 +123,92 @@ func (f *ClientFD) Sync(ctx context.Context) error {
return err
}
+// chunkify applies fn to buf in chunks based on chunkSize.
+func chunkify(chunkSize uint64, buf []byte, fn func([]byte, uint64) (uint64, error)) (uint64, error) {
+ toProcess := uint64(len(buf))
+ var (
+ totalProcessed uint64
+ curProcessed uint64
+ off uint64
+ err error
+ )
+ for {
+ if totalProcessed == toProcess {
+ return totalProcessed, nil
+ }
+
+ if totalProcessed+chunkSize > toProcess {
+ curProcessed, err = fn(buf[totalProcessed:], off)
+ } else {
+ curProcessed, err = fn(buf[totalProcessed:totalProcessed+chunkSize], off)
+ }
+ totalProcessed += curProcessed
+ off += curProcessed
+
+ if err != nil {
+ return totalProcessed, err
+ }
+
+ // Return partial result immediately.
+ if curProcessed < chunkSize {
+ return totalProcessed, nil
+ }
+
+ // If we received more bytes than we ever requested, this is a problem.
+ if totalProcessed > toProcess {
+ panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", totalProcessed, toProcess))
+ }
+ }
+}
+
// Read makes the PRead RPC.
func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) {
- req := PReadReq{
- Offset: offset,
- FD: f.fd,
- Count: uint32(len(dst)),
- }
+ var resp PReadResp
+ // maxDataReadSize represents the maximum amount of data we can read at once
+ // (maximum message size - metadata size present in resp). Uninitialized
+ // resp.SizeBytes() correctly returns the metadata size only (since the read
+ // buffer is empty).
+ maxDataReadSize := uint64(f.client.maxMessageSize) - uint64(resp.SizeBytes())
+ return chunkify(maxDataReadSize, dst, func(buf []byte, curOff uint64) (uint64, error) {
+ req := PReadReq{
+ Offset: offset + curOff,
+ FD: f.fd,
+ Count: uint32(len(buf)),
+ }
- resp := PReadResp{
// This will be unmarshalled into. Already set Buf so that we don't need to
// allocate a temporary buffer during unmarshalling.
// PReadResp.UnmarshalBytes expects this to be set.
- Buf: dst,
- }
-
- ctx.UninterruptibleSleepStart(false)
- err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
- ctx.UninterruptibleSleepFinish(false)
- return uint64(resp.NumBytes), err
+ resp.Buf = buf
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return uint64(resp.NumBytes), err
+ })
}
// Write makes the PWrite RPC.
func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) {
- req := PWriteReq{
- Offset: primitive.Uint64(offset),
- FD: f.fd,
- NumBytes: primitive.Uint32(len(src)),
- Buf: src,
- }
+ var req PWriteReq
+ // maxDataWriteSize represents the maximum amount of data we can write at
+ // once (maximum message size - metadata size present in req). Uninitialized
+ // req.SizeBytes() correctly returns the metadata size only (since the write
+ // buffer is empty).
+ maxDataWriteSize := uint64(f.client.maxMessageSize) - uint64(req.SizeBytes())
+ return chunkify(maxDataWriteSize, src, func(buf []byte, curOff uint64) (uint64, error) {
+ req = PWriteReq{
+ Offset: primitive.Uint64(offset + curOff),
+ FD: f.fd,
+ NumBytes: primitive.Uint32(len(buf)),
+ Buf: buf,
+ }
- var resp PWriteResp
- ctx.UninterruptibleSleepStart(false)
- err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
- ctx.UninterruptibleSleepFinish(false)
- return resp.Count, err
+ var resp PWriteResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Count, err
+ })
}
// MkdirAt makes the MkdirAt RPC.
diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go
index 476ff76a5..5fc7c364d 100644
--- a/pkg/lisafs/testsuite/testsuite.go
+++ b/pkg/lisafs/testsuite/testsuite.go
@@ -330,8 +330,8 @@ func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root li
defer closeFD(ctx, t, fd)
defer unix.Close(hostFD)
- // Test Read/Write RPCs.
- data := make([]byte, 100)
+ // Test Read/Write RPCs with 2MB of data to test IO in chunks.
+ data := make([]byte, 1<<21)
rand.Read(data)
if err := writeFD(ctx, t, fd, 0, data); err != nil {
t.Fatalf("write failed: %v", err)
diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go
index 9dac53c80..3f17fba49 100644
--- a/pkg/ring0/pagetables/pagetables.go
+++ b/pkg/ring0/pagetables/pagetables.go
@@ -322,12 +322,3 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc
func (p *PageTables) MarkReadOnlyShared() {
p.readOnlyShared = true
}
-
-// PrefaultRootTable touches the root table page to be sure that its physical
-// pages are mapped.
-//
-//go:nosplit
-//go:noinline
-func (p *PageTables) PrefaultRootTable() PTE {
- return p.root[0]
-}
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index f322d2747..7fcb2d26b 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -84,6 +84,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs.MaxCachedDentries = maxCachedDentries
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+ k := kernel.KernelFromContext(ctx)
+ fsDirChildren := make(map[string]kernfs.Inode)
+ // Create an empty directory to serve as the mount point for cgroupfs when
+ // cgroups are available. This emulates Linux behaviour, see
+ // kernel/cgroup.c:cgroup_init(). Note that in Linux, userspace (typically
+ // the init process) is ultimately responsible for actually mounting
+ // cgroupfs, but the kernel creates the mountpoint. For the sentry, the
+ // launcher mounts cgroupfs.
+ if k.CgroupRegistry() != nil {
+ fsDirChildren["cgroup"] = fs.newDir(ctx, creds, defaultSysDirMode, nil)
+ }
+
root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
"block": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"bus": fs.newDir(ctx, creds, defaultSysDirMode, nil),
@@ -97,7 +109,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}),
}),
"firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil),
- "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "fs": fs.newDir(ctx, creds, defaultSysDirMode, fsDirChildren),
"kernel": kernelDir(ctx, fs, creds),
"module": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"power": fs.newDir(ctx, creds, defaultSysDirMode, nil),
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
index 0a0d914cc..0c46a3a13 100644
--- a/pkg/sentry/fsimpl/sys/sys_test.go
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -87,3 +87,17 @@ func TestSysRootContainsExpectedEntries(t *testing.T) {
"power": linux.DT_DIR,
})
}
+
+func TestCgroupMountpointExists(t *testing.T) {
+ // Note: The mountpoint is only created if cgroups are available. This is
+ // the VFS2 implementation of sysfs and the test runs with VFS2 enabled, so
+ // we expect to see the mount point unconditionally.
+ s := newTestSystem(t)
+ defer s.Destroy()
+ pop := s.PathOpAtRoot("/fs")
+ s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{
+ "cgroup": linux.DT_DIR,
+ })
+ pop = s.PathOpAtRoot("/fs/cgroup")
+ s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ /*empty*/ })
+}
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 8a490b3de..a26f54269 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -1,13 +1,26 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "atomicptr_machine",
+ out = "atomicptr_machine_unsafe.go",
+ package = "kvm",
+ prefix = "machine",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
+ types = {
+ "Value": "machine",
+ },
+)
+
go_library(
name = "kvm",
srcs = [
"address_space.go",
"address_space_amd64.go",
"address_space_arm64.go",
+ "atomicptr_machine_unsafe.go",
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
@@ -69,10 +82,17 @@ go_test(
"kvm_amd64_test.go",
"kvm_amd64_test.s",
"kvm_arm64_test.go",
+ "kvm_safecopy_test.go",
"kvm_test.go",
"virtual_map_test.go",
],
library = ":kvm",
+ # FIXME(gvisor.dev/issue/3374): Not working with all build systems.
+ nogo = False,
+ # cgo has to be disabled. We have seen libc that blocks all signals and
+ # calls mmap from pthread_create, but we use SIGSYS to trap mmap system
+ # calls.
+ pure = True,
tags = [
"manual",
"nogotsan",
@@ -81,8 +101,10 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/hostarch",
+ "//pkg/memutil",
"//pkg/ring0",
"//pkg/ring0/pagetables",
+ "//pkg/safecopy",
"//pkg/sentry/arch",
"//pkg/sentry/arch/fpu",
"//pkg/sentry/platform",
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index bb9967b9f..826997e77 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -61,6 +61,9 @@ var (
// This is called by bluepillHandler.
savedHandler uintptr
+ // savedSigsysHandler is a pointer to the previos handler of the SIGSYS signals.
+ savedSigsysHandler uintptr
+
// dieTrampolineAddr is the address of dieTrampoline.
dieTrampolineAddr uintptr
)
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index c2a1dca11..5d8358f64 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -32,6 +32,8 @@
// This is checked as the source of the fault.
#define CLI $0xfa
+#define SYS_MMAP 9
+
// See bluepill.go.
TEXT ·bluepill(SB),NOSPLIT,$0
begin:
@@ -95,6 +97,31 @@ TEXT ·addrOfSighandler(SB), $0-8
MOVQ AX, ret+0(FP)
RET
+TEXT ·sigsysHandler(SB),NOSPLIT,$0
+ // Check if the signal is from the kernel.
+ MOVQ $1, CX
+ CMPL CX, 0x8(SI)
+ JNE fallback
+
+ MOVL CONTEXT_RAX(DX), CX
+ CMPL CX, $SYS_MMAP
+ JNE fallback
+ PUSHQ DX // First argument (context).
+ CALL ·seccompMmapHandler(SB) // Call the handler.
+ POPQ DX // Discard the argument.
+ RET
+fallback:
+ // Jump to the previous signal handler.
+ XORQ CX, CX
+ MOVQ ·savedSigsysHandler(SB), AX
+ JMP AX
+
+// func addrOfSighandler() uintptr
+TEXT ·addrOfSigsysHandler(SB), $0-8
+ MOVQ $·sigsysHandler(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
PUSHQ BX // First argument (vCPU).
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index 308f2a951..9690e3772 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -29,9 +29,12 @@
// Only limited use of the context is done in the assembly stub below, most is
// done in the Go handlers.
#define SIGINFO_SIGNO 0x0
+#define SIGINFO_CODE 0x8
#define CONTEXT_PC 0x1B8
#define CONTEXT_R0 0xB8
+#define SYS_MMAP 222
+
// getTLS returns the value of TPIDR_EL0 register.
TEXT ·getTLS(SB),NOSPLIT,$0-8
MRS TPIDR_EL0, R1
@@ -98,6 +101,37 @@ TEXT ·addrOfSighandler(SB), $0-8
MOVD R0, ret+0(FP)
RET
+// The arguments are the following:
+//
+// R0 - The signal number.
+// R1 - Pointer to siginfo_t structure.
+// R2 - Pointer to ucontext structure.
+//
+TEXT ·sigsysHandler(SB),NOSPLIT,$0
+ // si_code should be SYS_SECCOMP.
+ MOVD SIGINFO_CODE(R1), R7
+ CMPW $1, R7
+ BNE fallback
+
+ CMPW $SYS_MMAP, R8
+ BNE fallback
+
+ MOVD R2, 8(RSP)
+ BL ·seccompMmapHandler(SB) // Call the handler.
+
+ RET
+
+fallback:
+ // Jump to the previous signal handler.
+ MOVD ·savedHandler(SB), R7
+ B (R7)
+
+// func addrOfSighandler() uintptr
+TEXT ·addrOfSigsysHandler(SB), $0-8
+ MOVD $·sigsysHandler(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
// R0: Fake the old PC as caller
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 0f0c1e73b..e38ca05c0 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -193,36 +193,8 @@ func bluepillHandler(context unsafe.Pointer) {
return
}
- // Increment the fault count.
- atomic.AddUint32(&c.faults, 1)
-
- // For MMIO, the physical address is the first data item.
- physical = uintptr(c.runData.data[0])
- virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
- if !ok {
- c.die(bluepillArchContext(context), "invalid physical address")
- return
- }
-
- // We now need to fill in the data appropriately. KVM
- // expects us to provide the result of the given MMIO
- // operation in the runData struct. This is safe
- // because, if a fault occurs here, the same fault
- // would have occurred in guest mode. The kernel should
- // not create invalid page table mappings.
- data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1]))
- length := (uintptr)((uint32)(c.runData.data[2]))
- write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0
- for i := uintptr(0); i < length; i++ {
- b := bytePtr(uintptr(virtual) + i)
- if write {
- // Write to the given address.
- *b = data[i]
- } else {
- // Read from the given address.
- data[i] = *b
- }
- }
+ c.die(bluepillArchContext(context), "exit_mmio")
+ return
case _KVM_EXIT_IRQ_WINDOW_OPEN:
bluepillStopGuest(c)
case _KVM_EXIT_SHUTDOWN:
diff --git a/pkg/sentry/platform/kvm/kvm_safecopy_test.go b/pkg/sentry/platform/kvm/kvm_safecopy_test.go
new file mode 100644
index 000000000..9a87c9e6f
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_safecopy_test.go
@@ -0,0 +1,104 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// FIXME(gvisor.dev/issue//6629): These tests don't pass on ARM64.
+//
+//go:build amd64
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "os"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/safecopy"
+)
+
+func testSafecopy(t *testing.T, mapSize uintptr, fileSize uintptr, testFunc func(t *testing.T, c *vCPU, addr uintptr)) {
+ memfd, err := memutil.CreateMemFD(fmt.Sprintf("kvm_test_%d", os.Getpid()), 0)
+ if err != nil {
+ t.Errorf("error creating memfd: %v", err)
+ }
+
+ memfile := os.NewFile(uintptr(memfd), "kvm_test")
+ memfile.Truncate(int64(fileSize))
+ kvmTest(t, nil, func(c *vCPU) bool {
+ const n = 10
+ mappings := make([]uintptr, n)
+ defer func() {
+ for i := 0; i < n && mappings[i] != 0; i++ {
+ unix.RawSyscall(
+ unix.SYS_MUNMAP,
+ mappings[i], mapSize, 0)
+ }
+ }()
+ for i := 0; i < n; i++ {
+ addr, _, errno := unix.RawSyscall6(
+ unix.SYS_MMAP,
+ 0,
+ mapSize,
+ unix.PROT_READ|unix.PROT_WRITE,
+ unix.MAP_SHARED|unix.MAP_FILE,
+ uintptr(memfile.Fd()),
+ 0)
+ if errno != 0 {
+ t.Errorf("error mapping file: %v", errno)
+ }
+ mappings[i] = addr
+ testFunc(t, c, addr)
+ }
+ return false
+ })
+}
+
+func TestSafecopySigbus(t *testing.T) {
+ mapSize := uintptr(faultBlockSize)
+ fileSize := mapSize - hostarch.PageSize
+ buf := make([]byte, hostarch.PageSize)
+ testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) {
+ want := safecopy.BusError{addr + fileSize}
+ bluepill(c)
+ _, err := safecopy.CopyIn(buf, unsafe.Pointer(addr+fileSize))
+ if err != want {
+ t.Errorf("expected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func TestSafecopy(t *testing.T) {
+ mapSize := uintptr(faultBlockSize)
+ fileSize := mapSize
+ testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) {
+ want := uint32(0x12345678)
+ bluepill(c)
+ _, err := safecopy.SwapUint32(unsafe.Pointer(addr+fileSize-8), want)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ bluepill(c)
+ val, err := safecopy.LoadUint32(unsafe.Pointer(addr + fileSize - 8))
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if val != want {
+ t.Errorf("incorrect value: got %x, want %x", val, want)
+ }
+ })
+}
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index d67563958..dcf34015d 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -17,15 +17,19 @@ package kvm
import (
"fmt"
"runtime"
+ gosync "sync"
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/seccomp"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -35,6 +39,9 @@ type machine struct {
// fd is the vm fd.
fd int
+ // machinePoolIndex is the index in the machinePool array.
+ machinePoolIndex uint32
+
// nextSlot is the next slot for setMemoryRegion.
//
// This must be accessed atomically. If nextSlot is ^uint32(0), then
@@ -192,6 +199,10 @@ func (m *machine) newVCPU() *vCPU {
return c // Done.
}
+// readOnlyGuestRegions contains regions that have to be mapped read-only into
+// the guest physical address space. Right now, it is used on arm64 only.
+var readOnlyGuestRegions []region
+
// newMachine returns a new VM context.
func newMachine(vm int) (*machine, error) {
// Create the machine.
@@ -227,6 +238,10 @@ func newMachine(vm int) (*machine, error) {
m.upperSharedPageTables.MarkReadOnlyShared()
m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress)
+ // Install seccomp rules to trap runtime mmap system calls. They will
+ // be handled by seccompMmapHandler.
+ seccompMmapRules(m)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -241,32 +256,11 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
- var physicalRegionsReadOnly []physicalRegion
- var physicalRegionsAvailable []physicalRegion
-
- physicalRegionsReadOnly = rdonlyRegionsForSetMem()
- physicalRegionsAvailable = availableRegionsForSetMem()
-
- // Map all read-only regions.
- for _, r := range physicalRegionsReadOnly {
- m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY)
- }
-
// Ensure that the currently mapped virtual regions are actually
// available in the VM. Note that this doesn't guarantee no future
// faults, however it should guarantee that everything is available to
// ensure successful vCPU entry.
- applyVirtualRegions(func(vr virtualRegion) {
- if excludeVirtualRegion(vr) {
- return // skip region.
- }
-
- for _, r := range physicalRegionsReadOnly {
- if vr.virtual == r.virtual {
- return
- }
- }
-
+ mapRegion := func(vr region, flags uint32) {
for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
physical, length, ok := translateToPhysical(virtual)
if !ok {
@@ -280,9 +274,32 @@ func newMachine(vm int) (*machine, error) {
}
// Ensure the physical range is mapped.
- m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE)
+ m.mapPhysical(physical, length, physicalRegions, flags)
virtual += length
}
+ }
+
+ for _, vr := range readOnlyGuestRegions {
+ mapRegion(vr, _KVM_MEM_READONLY)
+ }
+
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return // skip region.
+ }
+ for _, r := range readOnlyGuestRegions {
+ if vr.virtual == r.virtual {
+ return
+ }
+ }
+ // Take into account that the stack can grow down.
+ if vr.filename == "[stack]" {
+ vr.virtual -= 1 << 20
+ vr.length += 1 << 20
+ }
+
+ mapRegion(vr.region, 0)
+
})
// Initialize architecture state.
@@ -352,6 +369,10 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg
func (m *machine) Destroy() {
runtime.SetFinalizer(m, nil)
+ machinePoolMu.Lock()
+ machinePool[m.machinePoolIndex].Store(nil)
+ machinePoolMu.Unlock()
+
// Destroy vCPUs.
for _, c := range m.vCPUsByID {
if c == nil {
@@ -683,3 +704,72 @@ func (c *vCPU) setSystemTimeLegacy() error {
}
}
}
+
+const machinePoolSize = 16
+
+// machinePool is enumerated from the seccompMmapHandler signal handler
+var (
+ machinePool [machinePoolSize]machineAtomicPtr
+ machinePoolLen uint32
+ machinePoolMu sync.Mutex
+ seccompMmapRulesOnce gosync.Once
+)
+
+func sigsysHandler()
+func addrOfSigsysHandler() uintptr
+
+// seccompMmapRules adds seccomp rules to trap mmap system calls that will be
+// handled in seccompMmapHandler.
+func seccompMmapRules(m *machine) {
+ seccompMmapRulesOnce.Do(func() {
+ // Install the handler.
+ if err := safecopy.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
+ }
+ rules := []seccomp.RuleSet{}
+ rules = append(rules, []seccomp.RuleSet{
+ // Trap mmap system calls and handle them in sigsysGoHandler
+ {
+ Rules: seccomp.SyscallRules{
+ unix.SYS_MMAP: {
+ {
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ /* MAP_DENYWRITE is ignored and used only for filtering. */
+ seccomp.MaskedEqual(unix.MAP_DENYWRITE, 0),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_TRAP,
+ },
+ }...)
+ instrs, err := seccomp.BuildProgram(rules, linux.SECCOMP_RET_ALLOW, linux.SECCOMP_RET_ALLOW)
+ if err != nil {
+ panic(fmt.Sprintf("failed to build rules: %v", err))
+ }
+ // Perform the actual installation.
+ if err := seccomp.SetFilter(instrs); err != nil {
+ panic(fmt.Sprintf("failed to set filter: %v", err))
+ }
+ })
+
+ machinePoolMu.Lock()
+ n := atomic.LoadUint32(&machinePoolLen)
+ i := uint32(0)
+ for ; i < n; i++ {
+ if machinePool[i].Load() == nil {
+ break
+ }
+ }
+ if i == n {
+ if i == machinePoolSize {
+ machinePoolMu.Unlock()
+ panic("machinePool is full")
+ }
+ atomic.AddUint32(&machinePoolLen, 1)
+ }
+ machinePool[i].Store(m)
+ m.machinePoolIndex = i
+ machinePoolMu.Unlock()
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index a96634381..ab1e036b7 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -309,22 +309,6 @@ func loadByte(ptr *byte) byte {
return *ptr
}
-// prefaultFloatingPointState touches each page of the floating point state to
-// be sure that its physical pages are mapped.
-//
-// Otherwise the kernel can trigger KVM_EXIT_MMIO and an instruction that
-// triggered a fault will be emulated by the kvm kernel code, but it can't
-// emulate instructions like xsave and xrstor.
-//
-//go:nosplit
-func prefaultFloatingPointState(data *fpu.State) {
- size := len(*data)
- for i := 0; i < size; i += hostarch.PageSize {
- loadByte(&(*data)[i])
- }
- loadByte(&(*data)[size-1])
-}
-
// SwitchToUser unpacks architectural-details.
func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
@@ -355,11 +339,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo)
// allocations occur.
entersyscall()
bluepill(c)
- // The root table physical page has to be mapped to not fault in iret
- // or sysret after switching into a user address space. sysret and
- // iret are in the upper half that is global and already mapped.
- switchOpts.PageTables.PrefaultRootTable()
- prefaultFloatingPointState(switchOpts.FloatingPointState)
vector = c.CPU.SwitchToUser(switchOpts)
exitsyscall()
@@ -522,3 +501,7 @@ func (m *machine) getNewVCPU() *vCPU {
}
return nil
}
+
+func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion {
+ return physicalRegions
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index de798bb2c..fbacea9ad 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -161,3 +161,15 @@ func (c *vCPU) getSystemRegisters(sregs *systemRegs) unix.Errno {
}
return 0
}
+
+//go:nosplit
+func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) {
+ ctx := bluepillArchContext(context)
+
+ // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters.
+ addr, _, e := unix.RawSyscall6(uintptr(ctx.Rax), uintptr(ctx.Rdi), uintptr(ctx.Rsi),
+ uintptr(ctx.Rdx), uintptr(ctx.R10)|unix.MAP_DENYWRITE, uintptr(ctx.R8), uintptr(ctx.R9))
+ ctx.Rax = uint64(addr)
+
+ return addr, uintptr(ctx.Rsi), e
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 7937a8481..08d98c479 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -110,18 +110,128 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
return phyRegions
}
+// archPhysicalRegions fills readOnlyGuestRegions and allocates separate
+// physical regions form them.
+func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion {
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return // skip region.
+ }
+ if !vr.accessType.Write {
+ readOnlyGuestRegions = append(readOnlyGuestRegions, vr.region)
+ }
+ })
+
+ rdRegions := readOnlyGuestRegions[:]
+
+ // Add an unreachable region.
+ rdRegions = append(rdRegions, region{
+ virtual: 0xffffffffffffffff,
+ length: 0,
+ })
+
+ var regions []physicalRegion
+ addValidRegion := func(r *physicalRegion, virtual, length uintptr) {
+ if length == 0 {
+ return
+ }
+ regions = append(regions, physicalRegion{
+ region: region{
+ virtual: virtual,
+ length: length,
+ },
+ physical: r.physical + (virtual - r.virtual),
+ })
+ }
+ i := 0
+ for _, pr := range physicalRegions {
+ start := pr.virtual
+ end := pr.virtual + pr.length
+ for start < end {
+ rdRegion := rdRegions[i]
+ rdStart := rdRegion.virtual
+ rdEnd := rdRegion.virtual + rdRegion.length
+ if rdEnd <= start {
+ i++
+ continue
+ }
+ if rdStart > start {
+ newEnd := rdStart
+ if end < rdStart {
+ newEnd = end
+ }
+ addValidRegion(&pr, start, newEnd-start)
+ start = rdStart
+ continue
+ }
+ if rdEnd < end {
+ addValidRegion(&pr, start, rdEnd-start)
+ start = rdEnd
+ continue
+ }
+ addValidRegion(&pr, start, end-start)
+ start = end
+ }
+ }
+
+ return regions
+}
+
// Get all available physicalRegions.
-func availableRegionsForSetMem() (phyRegions []physicalRegion) {
- var excludeRegions []region
+func availableRegionsForSetMem() []physicalRegion {
+ var excludedRegions []region
applyVirtualRegions(func(vr virtualRegion) {
if !vr.accessType.Write {
- excludeRegions = append(excludeRegions, vr.region)
+ excludedRegions = append(excludedRegions, vr.region)
}
})
- phyRegions = computePhysicalRegions(excludeRegions)
+ // Add an unreachable region.
+ excludedRegions = append(excludedRegions, region{
+ virtual: 0xffffffffffffffff,
+ length: 0,
+ })
- return phyRegions
+ var regions []physicalRegion
+ addValidRegion := func(r *physicalRegion, virtual, length uintptr) {
+ if length == 0 {
+ return
+ }
+ regions = append(regions, physicalRegion{
+ region: region{
+ virtual: virtual,
+ length: length,
+ },
+ physical: r.physical + (virtual - r.virtual),
+ })
+ }
+ i := 0
+ for _, pr := range physicalRegions {
+ start := pr.virtual
+ end := pr.virtual + pr.length
+ for start < end {
+ er := excludedRegions[i]
+ excludeEnd := er.virtual + er.length
+ excludeStart := er.virtual
+ if excludeEnd < start {
+ i++
+ continue
+ }
+ if excludeStart < start {
+ start = excludeEnd
+ i++
+ continue
+ }
+ rend := excludeStart
+ if rend > end {
+ rend = end
+ }
+ addValidRegion(&pr, start, rend-start)
+ start = excludeEnd
+ }
+ }
+
+ return regions
}
// nonCanonical generates a canonical address return.
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 1a4a9ce7d..7e8e19dcb 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -333,3 +333,15 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo)
}
}
+
+//go:nosplit
+func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) {
+ ctx := bluepillArchContext(context)
+
+ // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters.
+ addr, _, e := unix.RawSyscall6(uintptr(ctx.Regs[8]), uintptr(ctx.Regs[0]), uintptr(ctx.Regs[1]),
+ uintptr(ctx.Regs[2]), uintptr(ctx.Regs[3])|unix.MAP_DENYWRITE, uintptr(ctx.Regs[4]), uintptr(ctx.Regs[5]))
+ ctx.Regs[0] = uint64(addr)
+
+ return addr, uintptr(ctx.Regs[1]), e
+}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index cc3a1253b..cf3a4e7c9 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -171,3 +171,46 @@ func (c *vCPU) setSignalMask() error {
return nil
}
+
+// seccompMmapHandler is a signal handler for runtime mmap system calls
+// that are trapped by seccomp.
+//
+// It executes the mmap syscall with specified arguments and maps a new region
+// to the guest.
+//
+//go:nosplit
+func seccompMmapHandler(context unsafe.Pointer) {
+ addr, length, errno := seccompMmapSyscall(context)
+ if errno != 0 {
+ return
+ }
+
+ for i := uint32(0); i < atomic.LoadUint32(&machinePoolLen); i++ {
+ m := machinePool[i].Load()
+ if m == nil {
+ continue
+ }
+
+ // Map the new region to the guest.
+ vr := region{
+ virtual: addr,
+ length: length,
+ }
+ for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
+ physical, length, ok := translateToPhysical(virtual)
+ if !ok {
+ // This must be an invalid region that was
+ // knocked out by creation of the physical map.
+ return
+ }
+ if virtual+length > vr.virtual+vr.length {
+ // Cap the length to the end of the area.
+ length = vr.virtual + vr.length - virtual
+ }
+
+ // Ensure the physical range is mapped.
+ m.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE)
+ virtual += length
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
index d812e6c26..9864d1258 100644
--- a/pkg/sentry/platform/kvm/physical_map.go
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -168,6 +168,9 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica
}
addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd)
+ // Do arch-specific actions on physical regions.
+ physicalRegions = archPhysicalRegions(physicalRegions)
+
// Dump our all physical regions.
for _, r := range physicalRegions {
log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)",
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
index 6d0ba8252..346a10043 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
@@ -30,8 +30,8 @@ import (
func TLSWorks() bool
// SetTestTarget sets the rip appropriately.
-func SetTestTarget(regs *arch.Registers, fn func()) {
- regs.Pc = uint64(reflect.ValueOf(fn).Pointer())
+func SetTestTarget(regs *arch.Registers, fn uintptr) {
+ regs.Pc = uint64(fn)
}
// SetTouchTarget sets rax appropriately.
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
index 7348c29a5..42876245a 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -28,6 +28,11 @@ TEXT ·Getpid(SB),NOSPLIT,$0
SVC
RET
+TEXT ·AddrOfGetpid(SB),NOSPLIT,$0-8
+ MOVD $·Getpid(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·Touch(SB),NOSPLIT,$0
start:
MOVD 0(R8), R1
@@ -35,21 +40,41 @@ start:
SVC
B start
+TEXT ·AddrOfTouch(SB),NOSPLIT,$0-8
+ MOVD $·Touch(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·HaltLoop(SB),NOSPLIT,$0
start:
HLT
B start
+TEXT ·AddOfHaltLoop(SB),NOSPLIT,$0-8
+ MOVD $·HaltLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// This function simulates a loop of syscall.
TEXT ·SyscallLoop(SB),NOSPLIT,$0
start:
SVC
B start
+TEXT ·AddrOfSyscallLoop(SB),NOSPLIT,$0-8
+ MOVD $·SyscallLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·SpinLoop(SB),NOSPLIT,$0
start:
B start
+TEXT ·AddrOfSpinLoop(SB),NOSPLIT,$0-8
+ MOVD $·SpinLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·TLSWorks(SB),NOSPLIT,$0-8
NO_LOCAL_POINTERS
MOVD $0x6789, R5
@@ -125,6 +150,11 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
SVC
RET // never reached
+TEXT ·AddrOfTwiddleRegsSyscall(SB),NOSPLIT,$0-8
+ MOVD $·TwiddleRegsSyscall(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
TWIDDLE_REGS()
MSR R10, TPIDR_EL0
@@ -132,3 +162,8 @@ TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
// Branch to Register branches unconditionally to an address in <Rn>.
JMP (R6) // <=> br x6, must fault
RET // never reached
+
+TEXT ·AddrOfTwiddleRegsFault(SB),NOSPLIT,$0-8
+ MOVD $·TwiddleRegsFault(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD
index 943fa180d..35feb969f 100644
--- a/pkg/sentry/seccheck/BUILD
+++ b/pkg/sentry/seccheck/BUILD
@@ -8,6 +8,8 @@ go_fieldenum(
name = "seccheck_fieldenum",
srcs = [
"clone.go",
+ "execve.go",
+ "exit.go",
"task.go",
],
out = "seccheck_fieldenum.go",
@@ -29,6 +31,8 @@ go_library(
name = "seccheck",
srcs = [
"clone.go",
+ "execve.go",
+ "exit.go",
"seccheck.go",
"seccheck_fieldenum.go",
"seqatomic_checkerslice_unsafe.go",
diff --git a/pkg/sentry/seccheck/execve.go b/pkg/sentry/seccheck/execve.go
new file mode 100644
index 000000000..f36e0730e
--- /dev/null
+++ b/pkg/sentry/seccheck/execve.go
@@ -0,0 +1,65 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// ExecveInfo contains information used by the Execve checkpoint.
+//
+// +fieldenum Execve
+type ExecveInfo struct {
+ // Invoker identifies the invoking thread.
+ Invoker TaskInfo
+
+ // Credentials are the invoking thread's credentials.
+ Credentials *auth.Credentials
+
+ // BinaryPath is a path to the executable binary file being switched to in
+ // the mount namespace in which it was opened.
+ BinaryPath string
+
+ // Argv is the new process image's argument vector.
+ Argv []string
+
+ // Env is the new process image's environment variables.
+ Env []string
+
+ // BinaryMode is the executable binary file's mode.
+ BinaryMode uint16
+
+ // BinarySHA256 is the SHA-256 hash of the executable binary file.
+ //
+ // Note that this requires reading the entire file into memory, which is
+ // likely to be extremely slow.
+ BinarySHA256 [32]byte
+}
+
+// ExecveReq returns fields required by the Execve checkpoint.
+func (s *state) ExecveReq() ExecveFieldSet {
+ return s.execveReq.Load()
+}
+
+// Execve is called at the Execve checkpoint.
+func (s *state) Execve(ctx context.Context, mask ExecveFieldSet, info *ExecveInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.Execve(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/seccheck/exit.go b/pkg/sentry/seccheck/exit.go
new file mode 100644
index 000000000..69cb6911c
--- /dev/null
+++ b/pkg/sentry/seccheck/exit.go
@@ -0,0 +1,57 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// ExitNotifyParentInfo contains information used by the ExitNotifyParent
+// checkpoint.
+//
+// +fieldenum ExitNotifyParent
+type ExitNotifyParentInfo struct {
+ // Exiter identifies the exiting thread. Note that by the checkpoint's
+ // definition, Exiter.ThreadID == Exiter.ThreadGroupID and
+ // Exiter.ThreadStartTime == Exiter.ThreadGroupStartTime, so requesting
+ // ThreadGroup* fields is redundant.
+ Exiter TaskInfo
+
+ // ExitStatus is the exiting thread group's exit status, as reported
+ // by wait*().
+ ExitStatus linux.WaitStatus
+}
+
+// ExitNotifyParentReq returns fields required by the ExitNotifyParent
+// checkpoint.
+func (s *state) ExitNotifyParentReq() ExitNotifyParentFieldSet {
+ return s.exitNotifyParentReq.Load()
+}
+
+// ExitNotifyParent is called at the ExitNotifyParent checkpoint.
+//
+// The ExitNotifyParent checkpoint occurs when a zombied thread group leader,
+// not waiting for exit acknowledgement from a non-parent ptracer, becomes the
+// last non-dead thread in its thread group and notifies its parent of its
+// exiting.
+func (s *state) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info *ExitNotifyParentInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.ExitNotifyParent(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go
index b6c9d44ce..e13274096 100644
--- a/pkg/sentry/seccheck/seccheck.go
+++ b/pkg/sentry/seccheck/seccheck.go
@@ -29,6 +29,8 @@ type Point uint
// PointX represents the checkpoint X.
const (
PointClone Point = iota
+ PointExecve
+ PointExitNotifyParent
// Add new Points above this line.
pointLength
@@ -47,6 +49,8 @@ const (
// registered concurrently with invocations of checkpoints).
type Checker interface {
Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error
+ Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error
+ ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error
}
// CheckerDefaults may be embedded by implementations of Checker to obtain
@@ -58,6 +62,16 @@ func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info Clone
return nil
}
+// Execve implements Checker.Execve.
+func (CheckerDefaults) Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error {
+ return nil
+}
+
+// ExitNotifyParent implements Checker.ExitNotifyParent.
+func (CheckerDefaults) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error {
+ return nil
+}
+
// CheckerReq indicates what checkpoints a corresponding Checker runs at, and
// what information it requires at those checkpoints.
type CheckerReq struct {
@@ -69,7 +83,9 @@ type CheckerReq struct {
// All of the following fields indicate what fields in the corresponding
// XInfo struct will be requested at the corresponding checkpoint.
- Clone CloneFields
+ Clone CloneFields
+ Execve ExecveFields
+ ExitNotifyParent ExitNotifyParentFields
}
// Global is the method receiver of all seccheck functions.
@@ -101,7 +117,9 @@ type state struct {
// corresponding XInfo struct have been requested by any registered
// checker, are accessed using atomic memory operations, and are mutated
// with registrationMu locked.
- cloneReq CloneFieldSet
+ cloneReq CloneFieldSet
+ execveReq ExecveFieldSet
+ exitNotifyParentReq ExitNotifyParentFieldSet
}
// AppendChecker registers the given Checker to execute at checkpoints. The
@@ -110,7 +128,11 @@ type state struct {
func (s *state) AppendChecker(c Checker, req *CheckerReq) {
s.registrationMu.Lock()
defer s.registrationMu.Unlock()
+
s.cloneReq.AddFieldsLoadable(req.Clone)
+ s.execveReq.AddFieldsLoadable(req.Execve)
+ s.exitNotifyParentReq.AddFieldsLoadable(req.ExitNotifyParent)
+
s.appendCheckerLocked(c)
for _, p := range req.Points {
word, bit := p/32, p%32
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index ea56f39c1..0f6e576a9 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID {
}
// Action implements stack.Target.Action.
-func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
index 82b6df18c..7b9c2a4db 100644
--- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
+++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
@@ -37,6 +37,8 @@ func (p *AtomicPtr) loadPtr(v *Value) {
// Load returns the value set by the most recent Store. It returns nil if there
// has been no previous call to Store.
+//
+//go:nosplit
func (p *AtomicPtr) Load() *Value {
return (*Value)(atomic.LoadPointer(&p.ptr))
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 25f5a52e3..dda473e48 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -426,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -466,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -549,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -576,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -717,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -744,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -841,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -940,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index dab99d00d..e2d2cf907 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1100,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -1183,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 068dab7ce..4fb7e9adb 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
+ if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
@@ -209,27 +215,38 @@ type bucket struct {
tuples tupleList
}
+func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
+ return tcpHeader, true
+ }
+ case header.UDPProtocolNumber:
+ if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
+ return udpHeader, true
+ }
+ }
+
+ return nil, false
+}
+
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
//
// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
netHeader := pkt.Network()
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return tupleID{}, &tcpip.ErrUnknownProtocol{}
}
return tupleID{
srcAddr: netHeader.SourceAddress(),
- srcPort: tcpHeader.SourcePort(),
+ srcPort: transportHeader.SourcePort(),
dstAddr: netHeader.DestinationAddress(),
- dstPort: tcpHeader.DestinationPort(),
- transProto: netHeader.TransportProtocol(),
+ dstPort: transportHeader.DestinationPort(),
+ transProto: pkt.TransportProtocolNumber,
netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
- // TODO(gvisor.dev/issue/6168): Support UDP.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return false
}
@@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
}
netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return false
- }
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
@@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
switch hook {
case Prerouting, Output:
- if conn.manip == manipDestination {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- case dirReply:
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
-
- updateSRCFields = true
- }
+ if conn.manip == manipDestination && dir == dirOriginal {
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
+ pkt.NatDone = true
+ } else if conn.manip == manipSource && dir == dirReply {
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
pkt.NatDone = true
}
case Input, Postrouting:
- if conn.manip == manipSource {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
-
- updateSRCFields = true
- case dirReply:
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
- }
+ if conn.manip == manipSource && dir == dirOriginal {
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+ updateSRCFields = true
+ pkt.NatDone = true
+ } else if conn.manip == manipDestination && dir == dirReply {
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+ updateSRCFields = true
pkt.NatDone = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
+
if !pkt.NatDone {
return false
}
@@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
fullChecksum := false
updatePseudoHeader := false
switch hook {
- case Prerouting, Input:
+ case Prerouting:
+ // Packet came from outside the stack so it must have a checksum set
+ // already.
+ fullChecksum = true
+ updatePseudoHeader = true
+ case Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
fullChecksum = true
@@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
rewritePacket(
netHeader,
- tcpHeader,
+ transportHeader,
updateSRCFields,
fullChecksum,
updatePseudoHeader,
@@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// Mark the connection as having been used recently so it isn't reaped.
conn.lastUsed = time.Now()
// Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ conn.updateLocked(pkt, hook)
return false
}
@@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
- // We only track TCP connections.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber, header.UDPProtocolNumber:
+ default:
+ // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable
+ // connections.
return
}
@@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ conn.updateLocked(pkt, hook)
ct.insertConn(conn)
}
@@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
srcPort: epID.LocalPort,
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
- transProto: header.TCPProtocolNumber,
+ transProto: transProto,
netProto: netProto,
}
conn, _ := ct.connForTID(tid)
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index f152c0d83..74c9075b4 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -264,12 +264,62 @@ const (
chainReturn
)
-// Check runs pkt through the rules for hook. It returns true when the packet
+// CheckPrerouting performs the prerouting hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
+ return it.check(Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
+}
+
+// CheckInput performs the input hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
+ return it.check(Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+}
+
+// CheckForward performs the forward hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
+ return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName)
+}
+
+// CheckOutput performs the output hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
+ return it.check(Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+}
+
+// CheckPostrouting performs the postrouting hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool {
+ return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+}
+
+// check runs pkt through the rules for hook. It returns true when the packet
// should continue traversing the network stack and false when it should be
// dropped.
//
-// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
return true
}
@@ -300,7 +350,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -311,7 +361,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v {
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v {
case RuleAccept:
continue
case RuleDrop:
@@ -375,19 +425,35 @@ func (it *IPTables) startReaper(interval time.Duration) {
}()
}
-// CheckPackets runs pkts through the rules for hook and returns a map of packets that
-// should not go forward.
+// CheckOutputPackets performs the output hook on the packets.
//
-// Preconditions:
-// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// * pkt.NetworkHeader is not nil.
+// Returns a map of packets that must be dropped.
+//
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return it.checkPackets(Output, pkts, r, outNicName)
+}
+
+// CheckPostroutingPackets performs the postrouting hook on the packets.
+//
+// Returns a map of packets that must be dropped.
+//
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return it.checkPackets(Postrouting, pkts, r, outNicName)
+}
+
+// checkPackets runs pkts through the rules for hook and returns a map of
+// packets that should not go forward.
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+//
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok {
+ if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -407,11 +473,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -428,7 +494,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -454,7 +520,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
@@ -477,16 +543,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr)
+ return rule.Target.Action(pkt, &it.connections, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, &tcpip.ErrNotConnected{}
}
- return it.connections.originalDst(epID, netProto)
+ return it.connections.originalDst(epID, netProto, transProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 96cc899bb..e8806ebdb 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,7 +79,7 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -97,7 +97,7 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -117,6 +117,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
// Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
+ var address tcpip.Address
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
@@ -125,7 +126,8 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
address = header.IPv6Loopback
}
case Prerouting:
- // No-op, as address is already set correctly.
+ // addressEP is expected to be set for the prerouting hook.
+ address = addressEP.MainAddress().Address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
@@ -180,7 +182,7 @@ type SNATTarget struct {
}
// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if st.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -206,34 +208,28 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- header.UDP(pkt.TransportHeader().View()),
- true, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- st.Port,
- st.Addr,
- )
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
+ port := st.Port
+
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ if port == 0 {
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ }
+ case header.TCPProtocolNumber:
+ if port == 0 {
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ }
}
+ }
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
+ // Set up conection for matching NAT rule. Only the first packet of the
+ // connection comes here. Other packets will be manipulated in connection
+ // tracking.
+ //
+ // Does nothing if the protocol does not support connection tracking.
+ if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil {
+ ct.handlePacket(pkt, hook, r)
}
return RuleAccept, 0
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 66e5f22ac..976194124 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -352,5 +352,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int)
+ Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index b9280c2de..bf248ef20 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -335,9 +335,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
// tell if a noop connection should be inserted at Input hook. Once conntrack
// redefines the manipulation field as mutable, we won't need the special noop
// connection.
- if pk.NatDone {
- newPk.NatDone = true
- }
+ newPk.NatDone = pk.NatDone
return newPk
}
@@ -347,7 +345,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
// The returned packet buffer will have the network and transport headers
// set if the original packet buffer did.
func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
- newPkt := NewPacketBuffer(PacketBufferOptions{
+ newPk := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: reservedHeaderBytes,
Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(),
IsForwardedPacket: true,
@@ -355,21 +353,28 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu
{
consumeBytes := pk.NetworkHeader().View().Size()
- if _, consumed := newPkt.NetworkHeader().Consume(consumeBytes); !consumed {
+ if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
}
- newPkt.NetworkProtocolNumber = pk.NetworkProtocolNumber
+ newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
{
consumeBytes := pk.TransportHeader().View().Size()
- if _, consumed := newPkt.TransportHeader().Consume(consumeBytes); !consumed {
+ if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
}
- newPkt.TransportProtocolNumber = pk.TransportProtocolNumber
+ newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
- return newPkt
+ // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
+ // maintain this flag in the packet. Currently conntrack needs this flag to
+ // tell if a noop connection should be inserted at Input hook. Once conntrack
+ // redefines the manipulation field as mutable, we won't need the special noop
+ // connection.
+ newPk.NatDone = pk.NatDone
+
+ return newPk
}
// headerInfo stores metadata about a header in a packet.
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index c5e896295..d45a2c05c 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1269,6 +1269,8 @@ type TransportProtocolNumber uint32
type NetworkProtocolNumber uint32
// A StatCounter keeps track of a statistic.
+//
+// +stateify savable
type StatCounter struct {
count atomicbitops.AlignedAtomicUint64
}
@@ -1995,6 +1997,8 @@ type Stats struct {
}
// ReceiveErrors collects packet receive errors within transport endpoint.
+//
+// +stateify savable
type ReceiveErrors struct {
// ReceiveBufferOverflow is the number of received packets dropped
// due to the receive buffer being full.
@@ -2012,8 +2016,10 @@ type ReceiveErrors struct {
ChecksumErrors StatCounter
}
-// SendErrors collects packet send errors within the transport layer for
-// an endpoint.
+// SendErrors collects packet send errors within the transport layer for an
+// endpoint.
+//
+// +stateify savable
type SendErrors struct {
// SendToNetworkFailed is the number of packets failed to be written to
// the network endpoint.
@@ -2024,6 +2030,8 @@ type SendErrors struct {
}
// ReadErrors collects segment read errors from an endpoint read call.
+//
+// +stateify savable
type ReadErrors struct {
// ReadClosed is the number of received packet drops because the endpoint
// was shutdown for read.
@@ -2039,6 +2047,8 @@ type ReadErrors struct {
}
// WriteErrors collects packet write errors from an endpoint write call.
+//
+// +stateify savable
type WriteErrors struct {
// WriteClosed is the number of packet drops because the endpoint
// was shutdown for write.
@@ -2054,6 +2064,8 @@ type WriteErrors struct {
}
// TransportEndpointStats collects statistics about the endpoint.
+//
+// +stateify savable
type TransportEndpointStats struct {
// PacketsReceived is the number of successful packet receives.
PacketsReceived StatCounter
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 181ef799e..7c998eaae 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -34,12 +34,16 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
"//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 28b49c6be..bdf4a64b9 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -15,19 +15,24 @@
package iptables_test
import (
+ "bytes"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type inputIfNameMatcher struct {
@@ -1156,3 +1161,286 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
})
}
}
+
+func TestSNAT(t *testing.T) {
+ const listenPort = 8080
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+
+ nattedClientAddr tcpip.Address
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+ t.Cleanup(ep.Close)
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ipt := routerStack.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err)
+ }
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ {
+ name: "IPv6 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ipt := routerStack.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err)
+ }
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ expectedConnectErr tcpip.Error
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ needRemoteAddr bool
+ }{
+ {
+ name: "UDP",
+ proto: udp.ProtocolNumber,
+ expectedConnectErr: nil,
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ if err := ep.Connect(clientAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
+ }
+ return nil, nil
+ },
+ needRemoteAddr: true,
+ },
+ {
+ name: "TCP",
+ proto: tcp.ProtocolNumber,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
+ t.Helper()
+
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("ep.Listen(1): %s", err)
+ }
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ var addr tcpip.FullAddress
+ for {
+ newEP, wq, err := ep.Accept(&addr)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Accept(_): %s", err)
+ }
+ if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
+ "NIC",
+ )); diff != "" {
+ t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
+ }
+
+ we, newCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ return newEP, newCH
+ }
+ },
+ needRemoteAddr: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(serverAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
+ }
+ }
+ nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ nattedClientAddr.Port = addr.Port
+ }
+
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
+ }
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, nattedClientAddr)
+ }
+
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 947bcc7b1..c69410859 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -40,6 +40,14 @@ const (
Host2NICID = 4
)
+// Common NIC names used by tests.
+const (
+ Host1NICName = "host1NIC"
+ RouterNIC1Name = "routerNIC1"
+ RouterNIC2Name = "routerNIC2"
+ Host2NICName = "host2NIC"
+)
+
// Common link addresses used by tests.
const (
LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
@@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
- if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host1NICName}
+ if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil {
+ t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC1Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC2Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err)
+ }
}
- if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host2NICName}
+ if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil {
+ t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err)
+ }
}
if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index bbc0e3ecc..4718ec4ec 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 1e519085d..bb0db9f70 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,7 @@
package icmp
import (
+ "fmt"
"io"
"time"
@@ -24,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -35,15 +38,6 @@ type icmpPacket struct {
receivedAt time.Time `state:".(int64)"`
}
-type endpointState int
-
-const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
-)
-
// endpoint represents an ICMP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -51,14 +45,17 @@ const (
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -70,38 +67,23 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
- state endpointState
- route *stack.Route `state:"manual"`
- ttl uint8
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
+ ident uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
- state: stateInitial,
uniqueID: s.UniqueID(),
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetSendBufferSize(32*1024, false /* notify */)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ ep.net.Init(s, netProto, transProto, &ep.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -128,35 +110,40 @@ func (e *endpoint) Abort() {
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
- e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
- case stateBound, stateConnected:
- bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
- }
+ notify := func() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ return false
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice()))
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
- // Close the receive list and drain it.
- e.rcvMu.Lock()
- e.rcvClosed = true
- e.rcvBufSize = 0
- for !e.rcvList.Empty() {
- p := e.rcvList.Front()
- e.rcvList.Remove(p)
- }
- e.rcvMu.Unlock()
+ e.net.Shutdown()
+ e.net.Close()
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- // Update the state.
- e.state = stateClosed
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
- e.mu.Unlock()
+ return true
+ }()
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ if notify {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ }
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
@@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {}
// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -214,13 +201,12 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
//
// Returns true for retry if preparation should be retried.
// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.state {
- case stateInitial:
- case stateConnected:
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
-
- case stateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
- return 0, err
+ return network.WriteContext{}, 0, err
}
if !retry {
@@ -298,36 +272,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
}
- route := e.route
- if to != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := to.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
- dst, netProto, err := e.checkV4MappedLocked(*to)
- if err != nil {
- return 0, err
- }
-
- // Find the endpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return 0, err
- }
- defer r.Release()
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ return ctx, e.ident, err
+}
- route = r
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ ctx, ident, err := e.prepareForWrite(opts)
+ if err != nil {
+ return 0, err
}
+ defer ctx.Release()
// TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
@@ -335,17 +289,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return 0, &tcpip.ErrBadBuffer{}
}
- var err tcpip.Error
- switch e.NetProto {
+ switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
+ if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
case header.IPv6ProtocolNumber:
- err = send6(route, e.ID.LocalPort, v, e.ttl)
- }
-
- if err != nil {
- return 0, err
+ if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
+ default:
+ panic(fmt.Sprintf("unhandled network protocol = %d", netProto))
}
return int64(len(v)), nil
@@ -358,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool {
return e.stack.HasNIC(tcpip.NICID(id))
}
-// SetSockOpt sets a socket option.
-func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
- return nil
+// SetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ return e.net.SetSockOpt(opt)
}
-// SetSockOptInt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- }
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -388,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.rcvMu.Lock()
- v := int(e.ttl)
- e.rcvMu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+// GetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error {
+func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength),
})
- pkt.Owner = owner
icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
@@ -427,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest
-
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
-
pkt.Data().AppendView(data)
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V4.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
+func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength),
})
icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
@@ -469,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest
pkt.Data().AppendView(data)
dataRange := pkt.Data().AsRange()
icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpv6,
- Src: r.LocalAddress(),
- Dst: r.RemoteAddress(),
+ Src: src,
+ Dst: dst,
PayloadCsum: dataRange.Checksum(),
PayloadLen: dataRange.Size(),
}))
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V6.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
+ return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -516,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- localPort := uint16(0)
- switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.ident
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ nextID, err := e.registerWithStack(netProto, nextID)
+ if err != nil {
+ return err
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress(),
- LocalPort: localPort,
- RemoteAddress: r.RemoteAddress(),
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- id, err = e.registerWithStack(nicID, netProtos, id)
+ e.ident = nextID.LocalPort
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- e.ID = id
- e.route = r
- e.RegisterNICID = nicID
-
- e.state = stateConnected
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -586,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- e.shutdownFlags |= flags
- if e.state != stateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
}
if flags&tcpip.ShutdownRead != 0 {
@@ -616,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
- return id, err
+ return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
+ err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
@@ -645,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err := e.registerWithStack(boundNetProto, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ e.ident = id.LocalPort
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.RegisterNICID = addr.NIC
-
- // Mark endpoint as bound.
- e.state = stateBound
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -692,7 +571,7 @@ func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address)
return addr == header.IPv4Broadcast ||
header.IsV4MulticastAddress(addr) ||
header.IsV6MulticastAddress(addr) ||
- e.stack.IsSubnetBroadcast(nicID, e.NetProto, addr)
+ e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr)
}
// Bind binds the endpoint to a specific local address and port.
@@ -705,15 +584,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- err := e.bindLocked(addr)
- if err != nil {
- return err
- }
-
- e.BindNICID = addr.NIC
- e.BindAddr = addr.Addr
-
- return nil
+ return e.bindLocked(addr)
}
// GetLocalAddress returns the address to which the endpoint is bound.
@@ -721,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.ident
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -733,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
- return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
+ if addr, connected := e.net.GetRemoteAddress(); connected {
+ return addr, nil
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -766,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
- switch e.NetProto {
+ switch e.net.NetProto() {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
@@ -840,9 +705,9 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ defer e.mu.RUnlock()
+ ret := e.net.Info()
+ ret.ID.LocalPort = e.ident
return &ret
}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index b8b839e4a..dfe453ff9 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,11 +15,13 @@
package icmp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.thaw()
+
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- var err tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ info.ID, err = e.registerWithStack(info.NetProto, info.ID)
if err != nil {
- panic(err)
+ panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err))
}
-
- e.ID.LocalAddress = e.route.LocalAddress()
- } else if len(e.ID.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
- if err != nil {
- panic(err)
+ e.ident = info.ID.LocalPort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
index b1edce39b..3818cb04e 100644
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -9,6 +9,7 @@ go_library(
"endpoint_state.go",
],
visibility = [
+ "//pkg/tcpip/transport/icmp:__pkg__",
"//pkg/tcpip/transport/raw:__pkg__",
"//pkg/tcpip/transport/udp:__pkg__",
],
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index e4a64e191..689427d53 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -67,7 +67,7 @@ type endpoint struct {
waiterQueue *waiter.Queue
cooked bool
ops tcpip.SocketOptions
- stats tcpip.TransportEndpointStats `state:"nosave"`
+ stats tcpip.TransportEndpointStats
// The following fields are used to manage the receive queue.
rcvMu sync.Mutex `state:"nosave"`
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 3040a445b..bfef75da7 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -70,7 +70,7 @@ type endpoint struct {
associated bool
net network.Endpoint
- stats tcpip.TransportEndpointStats `state:"nosave"`
+ stats tcpip.TransportEndpointStats
ops tcpip.SocketOptions
// The following fields are used to manage the receive queue and are
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 5148fe157..20958d882 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -80,9 +80,10 @@ go_library(
go_test(
name = "tcp_x_test",
- size = "medium",
+ size = "large",
srcs = [
"dual_stack_test.go",
+ "rcv_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
"tcp_rack_test.go",
@@ -114,16 +115,6 @@ go_test(
)
go_test(
- name = "rcv_test",
- size = "small",
- srcs = ["rcv_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
name = "tcp_test",
size = "small",
srcs = [
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index ff0a5df9c..7115d0a12 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -193,14 +193,6 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
-func (l *listenContext) useSynCookies() bool {
- var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
- if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
- panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
- }
- return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull())
-}
-
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
@@ -277,7 +269,7 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
- l.listenEP.propagateInheritableOptionsLocked(ep)
+ l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce
if !ep.reserveTupleLocked() {
ep.mu.Unlock()
@@ -367,7 +359,6 @@ func (l *listenContext) closeAllPendingEndpoints() {
l.pending.Wait()
}
-// Precondition: h.ep.mu must be held.
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
@@ -384,7 +375,7 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// cleanupCompletedHandshake transfers any state from the completed handshake to
// the new endpoint.
//
-// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
if l.listenEP != nil {
@@ -404,7 +395,8 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
//
-// Precondition: e.mu and n.mu must be held.
+// +checklocks:e.mu
+// +checklocks:n.mu
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
n.portFlags = e.portFlags
@@ -415,9 +407,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// reserveTupleLocked reserves an accepted endpoint's tuple.
//
-// Preconditions:
-// * propagateInheritableOptionsLocked has been called.
-// * e.mu is held.
+// Precondition: e.propagateInheritableOptionsLocked has been called.
+//
+// +checklocks:e.mu
func (e *endpoint) reserveTupleLocked() bool {
dest := tcpip.FullAddress{
Addr: e.TransportEndpointInfo.ID.RemoteAddress,
@@ -459,7 +451,7 @@ func (e *endpoint) notifyAborted() {
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+// +checklocks:e.mu
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error {
defer s.decRef()
@@ -552,7 +544,7 @@ func (a *accepted) acceptQueueIsFullLocked() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+// +checklocks:e.mu
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
e.rcvQueueInfo.rcvQueueMu.Lock()
rcvClosed := e.rcvQueueInfo.RcvClosed
@@ -579,8 +571,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}
+ alwaysUseSynCookies := func() bool {
+ var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
+ panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
+ }
+ return bool(alwaysUseSynCookies)
+ }()
+
opts := parseSynSegmentOptions(s)
- if !ctx.useSynCookies() {
+ if !alwaysUseSynCookies && !e.synRcvdBacklogFull() {
s.incRef()
atomic.AddInt32(&e.synRcvdCount, 1)
return e.handleSynSegment(ctx, s, opts)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a3002abf3..407ab2664 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2066,7 +2066,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber)
e.UnlockUser()
if err != nil {
return err
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
index 8a026ec46..e47a07030 100644
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rcv_test
+package tcp_test
import (
"testing"
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index b355fa7eb..049957b81 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -60,9 +60,8 @@ type endpoint struct {
waiterQueue *waiter.Queue
uniqueID uint64
net network.Endpoint
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats tcpip.TransportEndpointStats `state:"nosave"`
- ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.